%%% Gripon-Berrou Neural Network, aka thrifty clique codes. Implementation in MatLab, by Vincent Gripon. Coded and tested on Octave 3.6.4 Windows.
% This is a minified version: only the core functionalities and comments are preserved, to give you a quick way to start experimenting with the code. This code is almost standalone: it only requires the gbnn_aux helper (which provides the shake and isOctave functions used below) to reduce the code complexity (even if we could inline it). Please see the full version (gbnn.m) for more options.
clear all;
close all;
% Importing auxiliary functions (shake, isOctave)
aux = gbnn_aux; % works with both MatLab and Octave
m = 15000 % number of messages
l = 256 % number of character neurons per cluster (= range of values allowed per character, eg: 256 for an image in 256 shades of grey per pixel). These l neurons form one cluster (eg: 256 neurons per cluster). NOTE: the expected density depends only on l and on the number of messages m, since d = 1 - (1 - 1/l^2)^m; changing the number of clusters c does not change it.
c = 8 % number of clusters (= length of messages = order of the cliques that will be formed, eg: c = 3 means that each clique will at most have 3 nodes)
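% NOTE: increasing c (or, in the full version, decreasing miterator) increases CPU usage and runtime; increasing l or m increases memory usage.
erasures = 4 % number of characters erased from each tampered message (with tampering_type = 'noise', the number of bits flipped per message instead). Tampering is done once, before the iterations. NOTE: must not exceed c!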
iterations = 4 % number of iterations to let the network converge
tests = 1 % number of tests
tampered_messages_per_test = 1000; % number of tampered messages to generate and try per test (for maximum speed, set this to the maximum allowed by your memory and decrease the number of tests)
tampering_type = 'erase' % 'erase' or 'noise'
gamma_memory = 0 % memory effect: keep a fraction of the nodes' previous values
residual_memory = 0 % residual memory: previously activated nodes linger a bit and participate in the next iteration
guiding_mask = false % guide the decoding by focusing only on the clusters where we know the characters may be (by cheating and using the initial message as a guide, but a less efficient guiding mask could be generated by other means). This is only useful with sparse_cliques (Chi > c), which is not implemented in the mini version: here every message activates all c clusters, so the mask keeps every cluster and has no effect.
threshold = 0 % activation threshold. Nodes having a score below this threshold will be deactivated (ie: they won't emit a spike). Unlike the article, this is just a scalar (a global threshold for all nodes) rather than a vector of per-cluster thresholds.
n = l * c; % total number of nodes
thriftymessages = logical(sparse(m,n)); % Init and converting to a binary sparse matrix
cnetwork = logical(sparse(n,n)); % init and converting to a binary sparse matrix
if erasures > c
error('Erasures > c which is not possible');
end
tic(); % for perfs
% #### Learning phase
% == Generate input messages
messages = randi([1 l], m, c); % Generate m messages (lines) of length c (columns) with a value between 1 and l. Use randi instead of unidrnd, the result is the same but does not necessitate the Statistics toolbox on MatLab (Octave natively supports it).
% == Convert into thrifty messages
% We convert values between 1 and l into sparse thrifty messages (constant weight code) of length l.
idxs_map = 0:(c-1); % character position index map in the thriftymessages matrix (eg: the first character is in the first l columns, the second character in the next l columns, etc.)
idxs = bsxfun(@plus, messages, l*idxs_map); % use messages matrix directly to compute the indexes (this will compute indexes independently of the row)
offsets = 0:(l*c):(m*l*c);
idxs = bsxfun(@plus, offsets(1:end-1)', idxs); % account for the row number now by generating a vector of index shift per row (eg: [0 lc 2lc 3lc 4lc ...]')
[I, J] = ind2sub([n m], idxs); % convert indexes to two subscripts vector, necessary to optimize sparse set (by using: sparsematrix = sparsematrix + sparse(...) instead of sparsematrix(idxs) = ...)
thriftymessages = sparse(I, J, 1, n, m)'; % store the messages (the indexes above were computed as if each message were a column of an n-by-m matrix, but we want one message per row, thus we just transpose the matrix)
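% Worked example of the conversion above (toy sizes chosen for illustration only, not run by this script): with l = 3, c = 2 and the m = 2 messages [2 3; 1 1], each character becomes a one-hot block of l columns. The offsets vector is built directly with m entries instead of being trimmed with (1:end-1), which gives the same values.
%   l = 3; c = 2; m = 2; n = l*c;
%   messages = [2 3; 1 1];
%   idxs = bsxfun(@plus, messages, l*(0:c-1));            % [2 6; 1 4]: column of each character
%   idxs = bsxfun(@plus, (0:(l*c):((m-1)*l*c))', idxs);   % [2 6; 7 10]: linear indexes into an n-by-m matrix
%   [I, J] = ind2sub([n m], idxs);
%   full(sparse(I, J, 1, n, m)')                          % [0 1 0 0 0 1; 1 0 0 1 0 0]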
% == Create network = learn the network
% We simply link all characters inside each message between them as a clique, which will result in an adjacency matrix
if aux.isOctave()
cnetwork = logical(thriftymessages' * thriftymessages); % same as min(cnetwork + thriftymessages'*thriftymessages, 1), but logical uses less memory (since values are binary!)
else % MatLab cannot do matrix multiplication on logical matrices...
dthriftymessages = double(thriftymessages);
cnetwork = logical(dthriftymessages' * dthriftymessages);
end
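% Toy illustration of the learning rule (not run by this script): with the two thrifty messages X = [0 1 0 0 0 1; 1 0 0 1 0 0] (l = 3, c = 2), X'*X links the nodes of each message into a clique, and also puts a self-loop on every used node, which is why the density below subtracts the diagonal.
%   X = [0 1 0 0 0 1; 1 0 0 1 0 0];
%   W = logical(X' * X);
%   [r, k] = find(triu(W, 1));   % links above the diagonal: r = [1; 2], k = [4; 6], ie: 1-4 and 2-6
%   nnz(diag(W))                 % 4 self-loops (nodes 1, 2, 4 and 6)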
% == Compute density and some practical and theoretical stats
density = full( (sum(sum(cnetwork)) - sum(diag(cnetwork))) / (c*(c-1) * l^2) ) % density = (number_of_links - loops) / max_number_of_links; where max_number_of_links = (l*l*c*(c-1))
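% For comparison (a sketch, not part of the original script): assuming messages are drawn uniformly at random, as randi does above, a given link between two clusters is created by one message with probability 1/l^2, so the expected density does not depend on c:
%   expected_density = 1 - (1 - 1/l^2)^m   % about 0.20 with l = 256 and m = 15000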
% #### Prediction phase
% == Run the test (error correction) and compute error rate (in reconstruction of original message without erasures)
err = 0; % error score
parfor t=1:tests
% -- Generation of a tampered message to remember
% 1- select a list of random messages
% Generate random indices to randomly choose messages
rndidx = randi([1 m], 1, tampered_messages_per_test);
% Fetch the random messages from the generated indices
init = transpose(thriftymessages(rndidx,:));
input = init; % working copy of the original messages, which we are going to tamper; init stays untouched so that we can check whether the network correctly corrected the erasures
% 2- randomly tamper them (erasure or noisy bit-flipping of a few characters)
% -- Random erasure of random active characters (the network tolerates erasures better than noise, which is expected and described in modern error correction theory)
if strcmpi(tampering_type, 'erase')
% The idea is that we will randomly pick several characters to erase per message (by extracting all nonzero indices, ordered per column/message, then shuffling them to finally select only a few indices per column pointing to the characters we will erase)
[~, idxs] = sort(input, 'descend'); % sort the messages to get the indices of the nonzero values, still organized per column (which find() doesn't provide)
idxs = idxs(1:c, :); % memory efficiency: trim out indices of the zero values (since we are sure that at most a message contains c characters)
idxs = aux.shake(idxs); % per-column shuffle indices! This is how we randomly pick characters.
idxs = idxs(1:erasures, :); % select the number of erasures we want
idxs = bsxfun(@plus, idxs, 0:l*c:l*c*(tampered_messages_per_test-1) ); % offset indices to take account of the column (since sort resets indices count per column)
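idxs(idxs == 0) = []; % remove invalid indices (only relevant to variable_length messages in the full version, which may have fewer characters than the number we want to erase; here idxs contains no zeros) TODO: ensure that a variable_length message keeps at least 2 nodes
input(idxs) = 0; % erase those characters (idxs already holds only the selected indices; indexing with idxs directly also works if the line above turned it into a vector)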
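% Toy illustration of the erasure trick above (not run by this script). aux.shake is defined in gbnn_aux and not shown here; the per-column shuffle below, built with sort(rand(...)), is an assumed stand-in with the same purpose, not the actual shake implementation.
%   X = logical([0 1; 1 0; 0 0; 1 1]);          % 2 messages (columns) of l*c = 4 nodes, c = 2 active nodes each
%   [~, idx] = sort(X, 'descend');              % the first c rows of idx hold the active rows of each column
%   idx = idx(1:2, :);
%   [~, p] = sort(rand(size(idx)));             % one random permutation per column
%   idx = idx(sub2ind(size(idx), p, repmat(1:size(idx, 2), size(idx, 1), 1)));
%   idx = bsxfun(@plus, idx(1, :), 0:4:4);      % keep erasures = 1 index per column, offset by column
%   X(idx) = false;                             % each message now has exactly one active node left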
% -- Random noise (bit-flipping of randomly selected characters)
elseif strcmpi(tampering_type, 'noise')
% The idea is simple: we generate random indices to be "noised" and we bit-flip them using modulo.
idxs = randi([1 l*c], erasures, tampered_messages_per_test); % generate random indices to be tampered (randi instead of unidrnd: unidrnd(N) draws from 1..N and does not take a [min max] range, and it requires the Statistics toolbox on MatLab)
idxs = bsxfun(@plus, idxs, 0:l*c:l*c*(tampered_messages_per_test-1) ); % offset indices so that each column (= message) addresses its own block of l*c nodes (Chi is not defined in the mini version)
input(idxs) = mod(input(idxs) + 1, 2); % bit-flipping! add one to all those entries and take the result modulo two, which effectively bit-flips them.
end
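% Toy illustration of the noise branch (not run by this script; the flip indices are fixed by hand instead of drawn with randi): a flip may erase a true node, add a spurious one, or hit the same node twice, in which case it is flipped only once.
%   X = logical([1 0; 0 1; 0 0; 1 1]);   % 2 messages (columns) of l*c = 4 nodes
%   idx = [1 3; 3 3];                    % 2 flips per message; message 2 draws node 3 twice
%   idx = bsxfun(@plus, idx, [0 4]);     % [1 7; 3 7]
%   X(idx) = mod(X(idx) + 1, 2);         % X is now [0 0; 0 1; 1 1; 1 1]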
% -- Prediction step: feed the tampered message to the network and wait for it to converge to a stable state, hopefully the corrected message.
if guiding_mask % if enabled, prepare the guiding mask (the list of clusters that we will keep, all the other nodes from other clusters will be set to 0). This guiding mask can be defined manually if you want, here to do it automatically we compute it from the initial untampered messages, thus we will keep nodes activated only in the clusters where there were activated nodes in the initial message.
gmask = any(reshape(init, l, tampered_messages_per_test * c)); % any is better than sum in our case, and it's also faster and keeps the logical datatype!
end
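% Toy illustration of the guiding mask (not run by this script): with l = 3, c = 2 and one message whose characters are 2 and 1, every cluster holds an active node, so the mask keeps every cluster. This is why the mask only matters with sparse_cliques in the full version.
%   init = logical([0; 1; 0; 1; 0; 0]);   % l*c = 6 nodes, one message
%   any(reshape(init, 3, 2))              % [1 1]: both clusters are kept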
for iter=1:iterations % To let the network converge towards a stable state...
% 1- Update the network's state: Push message and propagate through the network
if aux.isOctave()
propag = cnetwork * input; % Propagation rule (Sum of Sum): propagate the previous message state with a simple matrix product, so that each node's score is the number of currently active nodes it is linked to. Eg: if cnetwork = [0 1 0 ; 1 0 0 ; 0 0 0], input = [1 1 0]' matches well: [1 1 0]', while input = [0 0 1]' does not match at all: [0 0 0]'
else % MatLab cannot do matrix multiplication on logical matrices...
propag = double(cnetwork) * double(input);
end
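% Toy illustration of the Sum-of-Sum rule (not run by this script): nodes 1, 2 and 3 (in different clusters) form a learned clique and node 3 was erased; each node's score counts its active neighbours, so the erased node gets the highest score in its cluster. Self-loops are left out here for clarity; in cnetwork they add 1 to the score of every active node.
%   W = logical([0 1 1 0; 1 0 1 0; 1 1 0 0; 0 0 0 0]);
%   x = logical([1; 1; 0; 0]);
%   double(W) * double(x)   % [1; 1; 2; 0]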
if gamma_memory > 0; propag = propag + (gamma_memory .* input); end; % memory effect: keep a bit of the previous nodes scores
if threshold > 0; propag(propag < threshold) = 0; end; % activation threshold: set to 0 all nodes with a score lower than the activation threshold (ie: those nodes won't emit a spike)
% 2- Winner-takes-all filtering (shutdown all the other nodes, we keep only the most likely characters) - only local WTA is implemented here, see the full non-mini version to use other filtering rules
% The idea is that we will break the clusters and stack them along as a single long cluster spanning several messages, so that we can do a WTA in one pass (with a single max), and then we will unstack them back to their correct places in the messages
propag = reshape(propag, l, tampered_messages_per_test * c); % reshape so that we can do the WTA by a simple column-wise WTA (and it's efficient in MatLab since matrices, and even more so sparse matrices, are stored column by column, thus it uses the memory cache efficiently, which is more often the limiting factor than CPU power). See also: Locality of reference.
winner_value = max(propag); % what is the maximum output value (considering all the nodes in this character)
if ~aux.isOctave()
out = logical(bsxfun(@eq, propag, winner_value)); % Get the winners by filtering out nodes that aren't equal to the max score. Use bsxfun(@eq) instead of ismember.
else
out = logical(sparse(bsxfun(@eq, propag, winner_value))); % Octave does not overload bsxfun with sparse by default, so bsxfun removes the sparsity, thus the propagation at iteration > 1 will be horribly slow! Thus we have to put it back manually. MatLab DOES overload bsxfun with sparsity so this only applies to Octave. See: http://octave.1599824.n4.nabble.com/bsxfun-and-sparse-matrices-in-Matlab-td3867746.html
end
out = and(propag, out); % IMPORTANT NO FALSE WINNER TRICK: when every node of a cluster has a score of 0 (eg: with sparse_cliques or variable_length in the full version, or when the threshold zeroed the whole cluster), all of them equal the maximum and would be selected as winners. Thus we compare with the original propagation matrix and keep only the winners that really had a nonzero score.
out = reshape(out, l*c, tampered_messages_per_test); % reshape back to get one message per column
% End of WTA
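% Toy illustration of the local WTA with the no-false-winner trick (not run by this script): l = 3, c = 2 and one message, where cluster 1 has scores [0 2 1] and cluster 2 receives no support at all.
%   propag = [0; 2; 1; 0; 0; 0];
%   P = reshape(propag, 3, 2);                  % one cluster per column
%   out = and(P, bsxfun(@eq, P, max(P)));       % without and(P, ...), all 3 nodes of cluster 2 would "win" with score 0
%   reshape(out, 6, 1)'                         % [0 1 0 0 0 0]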
% 3- Some post-processing
if residual_memory > 0; out = out + (residual_memory .* input); end; % residual memory: previously activated nodes linger a bit and participate in the next iteration
if guiding_mask
out = reshape(out, l, tampered_messages_per_test * c);
if ~aux.isOctave()
out = bsxfun(@and, out, gmask);
else
out = logical(sparse(bsxfun(@and, out, gmask)));
end
out = reshape(out, l*c, tampered_messages_per_test);
end
input = out; % set next messages state as current
end
% -- Test score: compare the message corrected by the network with the original untampered message and check whether they are the same (partially corrected and uncorrected messages both count as errors)
if residual_memory > 0; input = max(round(input), 0); end; % if residual memory is enabled, we need to make sure that values are binary at the end, not just near-binary (eg: values of nodes with 0.000001 instead of just 0 due to the memory addition), else the error can't be computed correctly since some activated nodes will still linger after the last WTA!
err = err + nnz(sum((init ~= input), 1)); % fast computation of the incorrect messages (one or several errors doesn't matter: if there's at least one error, the message is counted as uncorrected)
end
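% Toy illustration of the error count (not run by this script): a message counts as a single error as soon as any of its nodes differs, whatever the number of wrong nodes. nnz(any(init ~= out, 1)) gives the same count.
%   init = logical([1 0; 0 1; 1 1]);
%   out  = logical([0 0; 1 1; 1 1]);   % message 1 has two wrong nodes, message 2 is correct
%   nnz(sum(init ~= out, 1))           % 1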
% Finally, show the error rate and some stats
error_rate = err / (tests * tampered_messages_per_test) % show error rate
total_tampered_messages_tested = tests * tampered_messages_per_test
toc() % print total time elapsed