Skip to content

Commit

Permalink
a few updates so that you should be able to use the gradient now
Browse files Browse the repository at this point in the history
  • Loading branch information
mobeets committed Dec 9, 2015
1 parent b63667e commit eb34bdb
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 34 deletions.
30 changes: 30 additions & 0 deletions +asd/+gauss/objfcn.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
function [nlogevi, nderlogevi] = objfcn(hyper, Ds, X, Y, XX, XY, YY, p, q, isLog)
if isLog
hyper(2:end) = exp(hyper(2:end));
end

[ro, ssq, deltas] = asd.unpackHyper(hyper);
Reg = asd.prior(ro, Ds, deltas);
[logEvi, sigmaInv, B, isNewBasis] = asd.gauss.logEvidence(X, Y, XX, ...
YY, XY, Reg, ssq, p, q);
nlogevi = -logEvi;
if nargout > 1
if isNewBasis
XY = (X*B)'*Y;
end
mu = tools.postMean(sigmaInv, XY, ssq);
if isNewBasis
mu = B*mu;
sigmaInv = B*sigmaInv*B';
end
sse = tools.sse(Y, X, mu);
Sigma = pinv(sigmaInv);
[der_ro, der_ssq, der_deltas] = asd.gauss.logEvidenceGradient(...
hyper, p, q, Ds, mu, Sigma, Reg, sse);
derlogevi = [der_ro, der_ssq, der_deltas];
nderlogevi = -derlogevi;
% if isLog
% nderlogevi(2:end) = -log(derlogevi(2:end));
% end
end
end
39 changes: 5 additions & 34 deletions +asd/+gauss/optMinNegLogEvi.m
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
willFitIntercept = true;
end
if nargin < 7
nRepeats = 5;
nRepeats = 10;
end
if nargin < 6 || isnan(noDeltaT)
noDeltaT = false;
Expand All @@ -32,7 +32,8 @@
gradObj = false;
end
if nargin < 4 || any(isnan(theta0))
theta0 = pickRandomTheta0(lbs, ubs, isLog);
theta0 = [1e-6 Y'*Y/numel(Y) ones(size(Ds,3),1)]';
% theta0 = pickRandomTheta0(lbs, ubs, isLog);
% theta0 = [1.1529 160.2757 1.0093 1.0048];
% theta0(2:end) = log(theta0(2:end))
end
Expand Down Expand Up @@ -60,7 +61,8 @@
'AlwaysHonorConstraints', 'none'); % interior-point, 'Active-Set'
% 'DerivativeCheck', 'on', ...

obj = @(hyper) objfcn(hyper, Ds, X, Y, XX, XY, YY, p, q, isLog);
obj = @(hyper) asd.gauss.objfcn(hyper, ...
Ds, X, Y, XX, XY, YY, p, q, isLog);
[hyper, fval] = fmincon(obj, theta0, ...
[], [], [], [], lbs, ubs, [], opts);

Expand Down Expand Up @@ -88,37 +90,6 @@
hyper = hyper';
end

function [nlogevi, nderlogevi] = objfcn(hyper, Ds, X, Y, XX, XY, YY, p, q, isLog)
if isLog
hyper(2:end) = exp(hyper(2:end));
end

[ro, ssq, deltas] = asd.unpackHyper(hyper);
Reg = asd.prior(ro, Ds, deltas);
[logEvi, sigmaInv, B, isNewBasis] = asd.gauss.logEvidence(X, Y, XX, ...
YY, XY, Reg, ssq, p, q);
nlogevi = -logEvi;
if nargout > 1
if isNewBasis
XY = (X*B)'*Y;
end
mu = tools.postMean(sigmaInv, XY, ssq);
if isNewBasis
mu = B*mu;
sigmaInv = B*sigmaInv*B';
end
sse = tools.sse(Y, X, mu);
Sigma = sigmaInv \ eye(q);
[der_ro, der_ssq, der_deltas] = asd.gauss.logEvidenceGradient(...
hyper, p, q, Ds, mu, Sigma, Reg, sse);
derlogevi = [der_ro, der_ssq, der_deltas];
nderlogevi = -derlogevi;
if isLog
nderlogevi(2:end) = -log(derlogevi(2:end));
end
end
end

function theta0 = pickRandomTheta0(lbs, ubs, isLog)
if ~isLog
lbs(2:end) = log(lbs(2:end)); ubs(2:end) = log(ubs(2:end));
Expand Down

0 comments on commit eb34bdb

Please sign in to comment.