function [w, V, invV, logdetV] = vb_logit_fit_iter(X, y)
%% [w, V, invV, logdetV] = vb_logit_fit_iter(X, y)
%
% returns the parameters of a fitted logit model
%
% p(y = 1 | x, w) = 1 / (1 + exp(- w' * x)).
%
% The function expects the arguments
% - X: N x D matrix of training input samples, one per row
% - y: N-element column vector of corresponding output {-1, 1} samples
%
% It returns
% - w: D-element posterior mean vector of the weights
% - V: D x D posterior covariance matrix of the weights
% - invV, logdetV: inverse of V, and its log-determinant
%
% The underlying generative model assumes a weight vector prior
%
% p(w) = N(w | 0, D^-1 I),
%
% where D is the dimensionality of the input x.
%
% The function returns the parameters of the posterior
%
% p(w1 | X, y) = N(w1 | w, V).
%
% Compared to vb_logit_fit[_ard], this function does not use a hyperprior
% and iterates over the inputs one at a time rather than processing them
% all at once. It is therefore slower, but numerically more stable, as it
% avoids inverting possibly close-to-singular matrices.
%
% Copyright (c) 2013-2019, Jan Drugowitsch
% All rights reserved.
% See the file LICENSE for licensing information.
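%
% Example (illustrative sketch with synthetic data; the last line is a
% plug-in prediction that ignores the posterior covariance V):
%
%   X = [ones(100, 1), randn(100, 2)];   % inputs, first column as bias
%   w_true = [0.5; 1; -2];
%   y = sign(1 ./ (1 + exp(- X * w_true)) - rand(100, 1));  % {-1, 1} labels
%   [w, V] = vb_logit_fit_iter(X, y);
%   p1 = 1 ./ (1 + exp(- X * w));        % approximate p(y = 1 | x)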

[N, D] = size(X);
max_iter = 500;
%% (more or less) uninformative prior
V = eye(D) / D;
invV = eye(D) * D;
logdetV = - D * log(D);
w = zeros(D, 1);
%% iterate over all x separately
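% Each observation is absorbed by a rank-one update of the posterior;
% the resulting (w, V) then acts as the prior for the next observation.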
for n = 1:N
    xn = X(n,:)';
    % precompute values
    Vx = V * xn;
    VxVx = Vx * Vx';
    c = xn' * Vx;
    xx = xn * xn';
    t_w = invV * w + 0.5 * y(n) * xn;
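    % t_w is the natural (precision-weighted) mean parameter: the prior
    % contributes invV * w and, under the quadratic Jaakkola-Jordan bound,
    % the likelihood contributes 0.5 * y(n) * xn; the posterior mean is
    % then recovered as w = V_xi * t_w below.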
    % start the iteration at xi = 0, where lam_xi = 1/8, so that the
    % updates below use 2 * lam_xi = 1/4
    V_xi = V - VxVx / (4 + c);
    invV_xi = invV + xx / 4;
    logdetV_xi = logdetV - log(1 + c / 4);
    w = V_xi * t_w;
    L_last = 0.5 * (logdetV_xi + w' * invV_xi * w) - log(2);
    for i = 1:max_iter
        % update xi by EM algorithm
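        % E-step: the optimal xi satisfies xi^2 = E[(w' * xn)^2] under the
        % current posterior, i.e. xi^2 = xn' * (V_xi + w * w') * xn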
        xi = sqrt(xn' * (V_xi + w * w') * xn);
        lam_xi = lam(xi);
        % Sherman-Morrison formula and matrix determinant lemma:
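        % (invV + 2 lam x x')^-1 = V - 2 lam / (1 + 2 lam x' V x) * (V x)(V x)'
        % det(invV + 2 lam x x') = det(invV) * (1 + 2 lam x' V x)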
        V_xi = V - (2 * lam_xi / (1 + 2 * lam_xi * c)) * VxVx;
        invV_xi = invV + 2 * lam_xi * xx;
        logdetV_xi = logdetV - log(1 + 2 * lam_xi * c);
        w = V_xi * t_w;
        L = 0.5 * (logdetV_xi + w' * invV_xi * w - xi) ...
            - log(1 + exp(- xi)) + lam_xi * xi^2;
        % the variational bound must grow!
        if L_last > L
            fprintf('Last bound %6.6f, current bound %6.6f\n', L_last, L);
            error('Variational bound should not reduce');
        end
        % stop if the change in the variational bound is < 0.001%
        if abs(L_last - L) < abs(0.00001 * L)
            break
        end
        L_last = L;
    end
    if i == max_iter
        warning('Bayes:maxIter', ...
            'Bayesian logistic regression reached maximum number of iterations.');
    end
    V = V_xi;
    invV = invV_xi;
    logdetV = logdetV_xi;
end
function out = lam(xi)
% returns 1 / (4 * xi) * tanh(xi / 2)
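% This is the lambda(xi) of the Jaakkola & Jordan quadratic lower bound on
% the logistic log-likelihood, lambda(xi) = (1/(2*xi)) * (1/(1+exp(-xi)) - 1/2),
% whose limit as xi -> 0 is 1/8 (restored below where the division yields NaN).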
divby0_w = warning('query', 'MATLAB:divideByZero');
warning('off', 'MATLAB:divideByZero');
out = tanh(xi ./ 2) ./ (4 .* xi);
warning(divby0_w.state, 'MATLAB:divideByZero');
% fix values where xi = 0
out(isnan(out)) = 1/8;