-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun_mgpca_gica.m
135 lines (103 loc) · 3.94 KB
/
run_mgpca_gica.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
function [data1, aux, isi] = run_mgpca_gica(X, S, M, num_pc, num_iter, seed, varargin)
% RUN_MGPCA_GICA Multimodal pipeline: MGPCA + group ICA + combinatorial optimization (CO).
%
%   [data1, aux, isi] = run_mgpca_gica(X, S, M, num_pc, num_iter, seed)
%   [data1, aux, isi] = run_mgpca_gica(X, S, M, num_pc, num_iter, seed, Y, A)
%
% Inputs
%   X        - cell array of per-modality data matrices, indexed by M. Each
%              X{m} is left-multiplied by an unmixing matrix, so its columns
%              are the observations.
%   S        - cell array of subspace matrices passed to MISAK;
%              size(S{1},1) sets the number of subspaces K.
%   M        - modality indices to use (iterated with a for loop).
%   num_pc   - number of components requested from ut.doMMGPCA.
%   num_iter - number of combinatorial-optimization rounds after the first fit.
%   seed     - seed passed to rng.
%   varargin - optional {Y, A}. Y is accepted but not used. When A is
%              supplied, data1.MISI(A) is recorded in isi after every round.
%
% Outputs
%   data1 - MISAK model over X. Its objective is last evaluated at
%           aux{1,ix}{mm} * W{mm}, where ix is the round with the lowest
%           recorded data2 objective.
%   aux   - 3 x (num_iter+1) cell, one column per round (column 1 = before CO):
%           row 1 data2.W, row 2 data2 objective (scale control on),
%           row 3 data3 objective (scale control off) at data2.W.
%   isi   - 1 x (num_iter+1) MISI values; all zeros when A is not supplied.

rng(seed);

% Optional inputs. varargin{1} (Y) is ignored; varargin{2} (A) enables MISI
% tracking. Checking numel avoids an indexing error when only Y is passed.
hasA = numel(varargin) >= 2;
if hasA
    A = varargin{2};
end

ut = utils;

% Kotz parameters set to the multivariate Laplace
K = size(S{1},1);
eta = ones(K,1);
beta = ones(K,1);
gradtype = 'relative';  % use the relative gradient
sc = 1;                 % enable scale control
preX = false;           % no preprocessing (the data mean is still removed)

%% Stage 1: MGPCA reduction, then group ICA on the reduced data H
[whtM, H] = ut.doMMGPCA(X, num_pc, 'WT');
nc = size(H,1);
% Initial unmixing: diagonal scaling giving each row of H std pi/sqrt(3)
w0 = ut.stackW({diag(pi/sqrt(3)./std(H,[],2))});
gica1 = MISAK(w0, 1, {eye(nc)}, {H}, ...
    0.5*ones(num_pc,1), ones(num_pc,1), ones(num_pc,1), ...
    gradtype, sc, preX);
% Infomax warm start. With 'sphering' 'off' the returned sphering matrix is
% the identity (not the PCA whitening matrix), so it is discarded.
[W1, ~] = icatb_runica(H, 'weights', gica1.W{1}, 'ncomps', nc, ...
    'sphering', 'off', 'verbose', 'off', 'posact', 'off', 'bias', 'on');
W1 = diag(pi/sqrt(3) ./ std(W1*H,[],2)) * W1;
% Continue the Infomax solution with MISA. Stochastic optimization is not
% used because MISA does not implement bias weights (yet):
% gica1.stochastic_opt('verbose', 'off', 'weights', gica1.W{1}, 'bias', 'off');%, 'block', 1100);
[~, ~, ~, ~] = ut.run_MISA(gica1, {W1});
% Rescale the GICA solution; evaluating the objective updates gica1.W{1}
gica1.objective(ut.stackW({diag(pi/sqrt(3) ./ std(gica1.Y{1},[],2))*gica1.W{1}}));

% Combine the GICA unmixing with each modality's whitening matrix, then
% rescale every component to std pi/sqrt(3) on that modality's data
W = cellfun(@(w) gica1.W{1}*w, whtM, 'Un', 0);
W = cellfun(@(w,x) diag(pi/sqrt(3) ./ std(w*x,[],2))*w, W, X, 'Un', 0);

%% Stage 2: multimodal models
data1 = MISAK(ut.stackW(W(M)), M, S, X, ...
    0.5*beta, eta, [], ...
    gradtype, sc, preX);
% data2/data3 act on data1.Y = data1.W * X, hence
% data2.Y = data2.W * data1.W * X. Both start from identity unmixing.
W0 = {};
for mm = M
    W0{mm} = eye(num_pc);
end
w0_short = ut.stackW(W0);
data2 = MISAK(w0_short, data1.M, data1.S, data1.Y, ...
    0.5*beta, eta, [], ...
    gradtype, sc, preX);
data3 = MISAK(w0_short, data1.M, data1.S, data1.Y, ...
    0.5*beta, eta, [], ...
    gradtype, false, preX); % scale control off

%% Stage 3: optimize data2, then alternate CO with re-optimization
f = @(x) data2.objective(x);
c = [];
barr = 1;                 % barrier parameter
m = 1;                    % past gradients kept by LBFGS-B (m = 1 ~ conjugate gradient)
N = size(X{M(1)}, 2);     % number of observations: columns of the data matrix
Tol = .5*N*1e-9;          % stopping tolerance
isi = zeros(1, num_iter+1);

% Optimize from the identity start so the initial W is in the feasible region
optprob = ut.getop(data2.stackW(data2.W), f, c, barr, {'lbfgs' m}, Tol);
[~, ~, ~, ~] = fmincon(optprob);
aux = {data2.W; data2.objective(ut.stackW(data2.W)); data3.objective(ut.stackW(data2.W))};
% No trailing semicolon: the objective value is echoed to the console
data1.objective(ut.stackW(compose_unmixing(data2.W, W, M)))
if hasA
    isi(1) = data1.MISI(A)
end
for ct = 2:num_iter+1
    data2.combinatorial_optim()
    optprob = ut.getop(ut.stackW(data2.W), f, c, barr, {'lbfgs' m}, Tol);
    [~, ~, ~, ~] = fmincon(optprob);
    aux(:,ct) = {data2.W; data2.objective_(); data3.objective(ut.stackW(data2.W))};
    data1.objective(ut.stackW(compose_unmixing(data2.W, W, M)))
    if hasA
        isi(ct) = data1.MISI(A)
    end
end

% Evaluate data1 at the round with the lowest recorded data2 objective
[~, ix] = min([aux{2,:}]);
data1.objective(ut.stackW(compose_unmixing(aux{1,ix}, W, M)));
end

function final_W = compose_unmixing(Wsub, W, M)
% Chain per-modality unmixing matrices: final_W{mm} = Wsub{mm} * W{mm} for mm in M.
% Wsub{mm} acts on the output of W{mm} (e.g. a 12x12 times a 12 x features matrix).
final_W = cell(1,2);
for mm = M
    final_W{mm} = Wsub{mm} * W{mm};
end
end