From 1d6153da8825284800ac4b062df7c670d1dee01a Mon Sep 17 00:00:00 2001 From: jovo Date: Tue, 19 Jul 2011 18:34:57 -0400 Subject: [PATCH] initial commit --- .DS_Store | Bin 0 -> 6148 bytes README | 1 + fast_oopsi.m | 410 ++++++++++++++++++++++++++++++++++ run_oopsi.m | 237 ++++++++++++++++++++ smc_oopsi.m | 117 ++++++++++ smc_oopsi_backward.m | 183 ++++++++++++++++ smc_oopsi_forward.m | 510 +++++++++++++++++++++++++++++++++++++++++++ smc_oopsi_m_step.m | 266 ++++++++++++++++++++++ z1.m | 11 + 9 files changed, 1735 insertions(+) create mode 100644 .DS_Store create mode 100644 README create mode 100755 fast_oopsi.m create mode 100755 run_oopsi.m create mode 100755 smc_oopsi.m create mode 100755 smc_oopsi_backward.m create mode 100755 smc_oopsi_forward.m create mode 100755 smc_oopsi_m_step.m create mode 100755 z1.m diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 GIT binary patch literal 6148 zcmeH~Jr2S!425mzP>H1@V-^m;4Wg<&0T*E43hX&L&p$$qDprKhvt+--jT7}7np#A3 zem<@ulZcFPQ@L2!n>{z**++&mCkOWA81W14cNZlEfg7;MkzE(HCqgga^y>{tEnwC%0;vJ&^%eQ zLs35+`xjp>T0= 0} P(n | F) +% which is a MAP estimate for the most likely spike train given the +% fluorescence signal. given the model: +% +% +% \begin{align} +% C_t &= \gamma*C_{t-1} + n_t, \qquad & n_t & \sim \text{Poisson}(n_t; \lamda_t \Delta) +% F_t &= \alpha(C_t + \beta) + \sigma \varepsilon_t, &\varepsilon_t &\sim \mathcal{N}(0,1) +% \end{align} +% +% +% if F_t is a vector, then 'a' is a vector as well +% we approx the Poisson with an Exponential (which means we don't require integer numbers of spikes). +% we take an "interior-point" approach to impose the nonnegative contraint on (*). +% each step is solved in O(T) +% time by utilizing gaussian elimination on the tridiagonal hessian, as +% opposed to the O(T^3) time typically required for non-negative +% deconvolution. +% +% Input---- only F is REQUIRED. the others are optional +% F: fluorescence time series (can be a vector (1 x T) or a matrix (Np x T) +% +% V. structure of algorithm Variables +% Ncells: # of cells within ROI +% T: # of time steps +% Npixels:# of pixels in ROI +% dt: time step size, ie, frame duration, ie, 1/(imaging rate) +% n: if true spike train is known, and we are plotting, plot it (only required is est_a==1) +% h: height of ROI (assumes square ROI) (# of pixels) (only required if est_a==1 and we are plotting) +% w: width of ROI (assumes square ROI) (# of pixels) (only required if est_a==1 and we are plotting) +% +% THE FOLLOWING FIELDS CORRESPOND TO CHOICES THAT THE USER MAKE +% +% fast_poiss: 1 if F_t ~ Poisson, 0 if F_t ~ Gaussian +% fast_nonlin: 1 if F_t is a nonlinear f(C_t), and 0 if F_t is a linear f(C_t) +% fast_plot: 1 to plot results after each pseudo-EM iteration, 0 otherwise +% fast_thr: 1 if thresholding inferred spike train before estiamting {a,b} +% fast_iter_max: max # of iterations of pseudo-EM (1 to use default initial parameters) +% fast_ignore_post: 1 to keep iterating pseudo-EM even if posterior is not increasing, 0 otherwise +% +% THE BELOW FIELDS INDICATE WHETHER ONE WANTS TO ESTIMATE EACH OF THE +% PARAMETERS. IF ANY IS SET TO ZERO, THEN WE DO NOT TRY TO UPDATE THE +% ORIGINAL ESTIMATE, GIVEN EITHER BY THE USER, OR THE INITIAL ESTIMATE +% FROM THE CODE +% +% est_sig: 1 to estimate sig +% est_lam: 1 to estimate lam +% est_gam: 1 to estimate gam +% est_b: 1 to estimate b +% est_a: 1 to estimate a +% +% P. structure of neuron model Parameters +% +% a: spatial filter +% b: background fluorescence +% sig: standard deviation of observation noise +% gam: decayish, ie, tau=dt/(1-gam) +% lam: firing rate-ish, ie, expected # of spikes per frame +% +% Output--- +% n_best: inferred spike train +% P_best: inferred parameter structure +% V: structure of Variables for algorithm to run + +%% initialize algorithm Variables +starttime = cputime; +siz = size(F); if siz(2)==1, F=F'; siz=size(F); end + +% variables determined by the data +if nargin < 2, V = struct; end +if ~isfield(V,'Ncells'), V.Ncells = 1; end % # of cells in image +if ~isfield(V,'T'), V.T = siz(2); end % # of time steps +if ~isfield(V,'Npixels'), V.Npixels = siz(1); end % # of pixels in ROI +if ~isfield(V,'dt'), % frame duration + fr = input('\nwhat was the frame rate for this movie (in Hz)?: '); + V.dt = 1/fr; +end + +% variables determined by the user +if ~isfield(V,'fast_poiss'),V.fast_poiss = 0; end % whether observations are Poisson +if ~isfield(V,'fast_nonlin'), V.fast_nonlin = 0; end +if V.fast_poiss && V.fast_nonlin, + reply = input('\ncan be nonlinear observations and poisson, \ntype 1 for nonlin, 2 for poisson, anything else for neither: '); + if reply==1, V.fast_poiss = 0; V.fast_nonlin = 1; + elseif reply==2, V.fast_poiss = 1; V.fast_nonlin = 0; + else V.fast_poiss = 0; V.fast_nonlin = 0; + end +end +if ~isfield(V,'fast_iter_max'), V.fast_iter_max=1; end % max # of iterations before convergence + +% things that matter if we are iterating to estimate parameters +if V.fast_iter_max>1; + if V.fast_poiss || V.fast_nonlin, + disp('\ncode does not currrently support estimating parameters for \npoisson or nonlinear observations'); + V.fast_iter_max=1; + end + + if ~isfield(V,'fast_plot'), V.fast_plot = 0; end + if V.fast_plot==1 + FigNum = 400; + if V.Npixels>1, figure(FigNum), clf, end % figure showing estimated spatial filter + figure(FigNum+1), clf % figure showing estimated spike trains + if isfield(V,'n'), siz=size(V.n); V.n(V.n==0)=NaN; if siz(1)1 + options = optimset('Display','off'); % don't show warnings for parameter estimation + i = 1; % iteration # + i_best = i; % iteration with highest likelihood + conv = 0; % whether algorithm has converged yet +else + conv = 1; +end + +%% if parameters are unknown, do pseudo-EM iterations +while conv == 0 + if V.fast_plot == 1, MakePlot(n,F,P,V); end % plot results from previous iteration + i = i+1; % update iteratation number + V.fast_iter_tot = i; % record of total # of iterations + P = est_params(n,C,F,P,b); % update parameters based on previous iteration + [n C posts(i)] = est_MAP(F,P); % update inferred spike train based on new parameters + + if posts(i)>post_max || V.fast_ignore_post==1% if this is the best one, keep n and P + n_best = n; % keep n + P_best = P; % keep P + i_best = i; % keep track of which was best + post_max= posts(i); % keep max posterior + end + + % if lik doesn't change much (relatively), or returns to some previous state, stop iterating + if i>=V.fast_iter_max || (abs((posts(i)-posts(i-1))/posts(i))<1e-3 || any(posts(1:i-1)-posts(i))<1e-5)% abs((posts(i)-posts(i-1))/posts(i))<1e-5 || posts(i-1)-posts(i)>1e5; + MakePlot(n,F,P,V); + disp('convergence criteria met') + V.post = posts(1:i); + conv = 1; + end + sound(3*sin(linspace(0,90*pi,2000))) % play sound to indicate iteration is over +end + +V.fast_time = cputime-starttime; % time to run code +V = orderfields(V); % order fields alphabetically to they are easier to read +P_best = orderfields(P_best); +% n_best = n_best./repmat(max(n_best),V.T,1); + +%% fast filter function + function [n C post] = est_MAP(F,P) + + % initialize n and C + z = 1; % weight on barrier function + llam = reshape(1./lam',1,V.Ncells*V.T)'; + if V.fast_nonlin==1 + n = V.gauss_n; + else + n = 0.01+0*llam; % initialize spike train + end + C = 0*n; % initialize calcium + for j=1:V.Ncells + C(j:V.Ncells:end) = filter(1,[1, -P.gam(j)],n(j:V.Ncells:end)); %(1-P.gam(j))*P.b(j); + end + + % precompute parameters required for evaluating and maximizing likelihood + b = repmat(P.b,1,V.T); % for lik + if V.fast_poiss==1 + suma = sum(P.a); % for grad + else + M(d1) = -repmat(P.gam,V.T-1,1); % matrix transforming calcium into spikes, ie n=M*C + ba = P.a'*b; ba=ba(:); % for grad + aa = repmat(diag(P.a'*P.a),V.T,1);% for grad + aF = P.a'*F; aF=aF(:); % for grad + e = 1/(2*P.sig^2); % scale of variance + H1(d0) = 2*e*aa; % for Hess + end + lnprior = llam.*sum(M,2); % for grad + + % find C = argmin_{C_z} lik + prior + barrier_z + while z>1e-13 % this is an arbitrary threshold + + if V.fast_poiss==1 + Fexpect = P.a*(C+b')'; % expected poisson observation rate + lik = -sum(sum(-Fexpect+ F.*log(Fexpect) - gamlnF)); % lik + else + if V.fast_nonlin==1 + S = C./(C+P.k_d); + else + S = C; + end + D = F-P.a*(reshape(S,V.Ncells,V.T))-b; % difference vector to be used in likelihood computation + lik = e*D(:)'*D(:); % lik + end + post = lik + llam'*n - z*sum(log(n)); + s = 1; % step size + d = 1; % direction + while norm(d)>5e-2 && s > 1e-3 % converge for this z (again, these thresholds are arbitrary) + if V.fast_poiss==1 + glik = suma - sumF./(C+b'); + H1(d0) = sumF.*(C+b').^(-2); % lik contribution to Hessian + elseif V.fast_nonlin==1 + glik = -2*P.a*P.k_d*D'.*(C+P.k_d).^-2; + H1diag = (-P.a*P.k_d-2*(C+P.k_d).*D').*((C+P.k_d).^-4); + H1(d0) = H1diag; + else + glik = -2*e*(aF-aa.*C-ba); % gradient + end + g = glik + lnprior - z*M'*(n.^-1); + H2(d0) = n.^-2; % log barrier part of the Hessian + H = H1 + z*(M'*H2*M); % Hessian + d = -H\g; % direction to step using newton-raphson + hit = -n./(M*d); % step within constraint boundaries + hit(hit<0)=[]; % ignore negative hits + if any(hit<1) + s = min(1,0.99*min(hit(hit>0))); + else + s = 1; + end + post1 = post+1; + while post1>=post+1e-7 % make sure newton step doesn't increase objective + C1 = C+s*d; + n = M*C1; + if V.fast_poiss==1 + Fexpect = P.a*(C1+b')'; + lik1 = -sum(sum(-Fexpect+ F.*log(Fexpect) - gamlnF)); + else + if V.fast_nonlin==1 + S1 = C1./(C1+P.k_d); + else + S1 = C1; + end + D = F-P.a*(reshape(S1,V.Ncells,V.T))-b; % difference vector to be used in likelihood computation + lik1 = e*D(:)'*D(:); % lik + end + post1 = lik1 + llam'*n - z*sum(log(n)); + s = s/5; % if step increases objective function, decrease step size + if s<1e-20; disp('reducing s further did not increase likelihood'), break; end % if decreasing step size just doesn't do it + end + C = C1; % update C + post = post1; % update post + end + z=z/10; % reduce z (sequence of z reductions is arbitrary) + end + + % reshape things in the case of multiple neurons within the ROI + n=reshape(n,V.Ncells,V.T)'; + C=reshape(C,V.Ncells,V.T)'; + end + +%% Parameter Update + function P = est_params(n,C,F,P,b) + + % generate regressor for spatial filter + if V.est_a==1 || V.est_b==1 + if V.fast_thr==1 + CC=0*C; + for j=1:V.Ncells + nsort = sort(n(:,j)); + nthr = nsort(round(0.98*V.T)); + nn = Z(1:V.T); + nn(n(:,j)<=nthr)=0; + nn(n(:,j)>nthr)=1; + CC(:,j) = filter(1,[1 -P.gam(j)],nn) + (1-P.gam(j))*P.b(j); + end + else + CC = C; + end + + if V.est_b==1 + A = [CC -1+Z(1:V.T)]; + else + A=CC; + end + X = A\F'; + + P.a = X(1:V.Ncells,:)'; + if V.est_b==1 + P.b = X(end,:)'; + b = repmat(P.b,1,V.T); + end + + D = F-P.a*(reshape(C,V.Ncells,V.T)) - b; + + mse = D(:)'*D(:); + end + + if V.est_a==0 && V.est_b==0 && (V.est_sig==1 || V.est_lam==1), + D = F-P.a*(reshape(C,V.Ncells,V.T)+b); + mse = D(:)'*D(:); + end + + % estimate other parameters + if V.est_sig==1, + P.sig = sqrt(mse)/V.T; + end + if V.est_lam==1, + nnorm = n./repmat(max(n),V.T,1); + if numel(P.lam)==V.Ncells + P.lam = sum(nnorm)'/(V.T*V.dt); + lam = repmat(P.lam,V.T,1)*V.dt; + else + P.lam = nnorm/(V.T*V.dt); + lam = P.lam*V.dt; + end + + end + end + +%% MakePlot + function MakePlot(n,F,P,V) + if V.fast_plot == 1 + if V.Npixels>1 % plot spatial filter + figure(FigNum), nrows=V.Ncells; + for j=1:V.Ncells, subplot(1,nrows,j), + imagesc(reshape(P.a(:,j),V.w,V.h)), + title('a') + end + end + + figure(FigNum+1), ncols=V.Ncells; nrows=3; END=V.T; h=zeros(V.Ncells,2); + for j=1:V.Ncells % plot inferred spike train + h(j,1)=subplot(nrows,ncols,(j-1)*ncols+1); cla + if V.Npixels>1, Ftemp=mean(F); else Ftemp=F; end + plot(z1(Ftemp(2:END))+1), hold on, + bar(z1(n_best(2:END,j))) + title(['best iteration ' num2str(i_best)]), + axis('tight') + set(gca,'XTickLabel',[],'YTickLabel',[]) + + h(j,2)=subplot(nrows,ncols,(j-1)*ncols+2); cla + bar(z1(n(2:END,j))) + if isfield(V,'n'), hold on, + for k=1:V.Ncells + stem(V.n(2:END,k)+k/10,'LineStyle','none','Marker','v','MarkerEdgeColor','k','MarkerFaceColor','k','MarkerSize',2) + end + end + set(gca,'XTickLabel',[],'YTickLabel',[]) + title(['current iteration ' num2str(i)]), + axis('tight') + end + + subplot(nrows,ncols,j*nrows), + plot(1:i,posts(1:i)) % plot record of likelihoods + title(['max lik ' num2str(post_max,4), ', lik ' num2str(posts(i),4)]) + set(gca,'XTick',2:i,'XTickLabel',2:i) + drawnow + end + end +end \ No newline at end of file diff --git a/run_oopsi.m b/run_oopsi.m new file mode 100755 index 0000000..7f657ef --- /dev/null +++ b/run_oopsi.m @@ -0,0 +1,237 @@ +function varargout = run_oopsi(F,V,P) +% this function runs our various oopsi filters, saves the results, and +% plots the inferred spike trains. make sure that fast-oopsi and +% smc-oopsi repository are in your path if you intend to use them. +% +% to use the code, simply provide F, a vector of fluorescence observations, +% for each cell. the fast-oopsi code can handle a matrix input, +% corresponding to a set of pixels, for each time bin. smc-oopsi expects a +% 1D fluorescence trace. +% +% see documentation for fast-oopsi and smc-oopsi to determine how to set +% variables +% +% input +% F: fluorescence from a single neuron +% V: Variables necessary to define for code to function properly (optional) +% P: Parameters of the model (optional) +% +% possible outputs +% fast: fast-oopsi MAP estimate of spike train, argmax_{n\geq 0} P[n|F], (fast.n), parameter estimate (fast.P), and structure of variables for algorithm (fast.V) +% smc: smc-oopsi estimate of {P[X_t|F]}_{t 0 + fprintf('\nfast-oopsi\n') + [fast.n fast.P fast.V]= fast_oopsi(F,V,P); + if V.save, save(V.name_dat,'fast','-append'); end +end + +stupid=1; +% infer spikes using smc-oopsi +if V.smc_iter_max > 0 + fprintf('\nsmc-oopsi\n') + siz=size(F); if siz(1)>1, F=F'; end + if V.fast_iter_max > 0; + if ~isfield(P,'A'), P.A = 50; end % initialize jump in [Ca++] after spike + if ~isfield(P,'n'), P.n = 1; end % Hill coefficient + if ~isfield(P,'k_d'), P.k_d = 200; end % dissociation constant + if ~isfield(V,'T'), V.T = fast.V.T; end % number of time steps + if ~isfield(V,'dt'), V.dt = fast.V.dt; end % frame duration, aka, 1/(framte rate) + + if ~exist('stupid') + bign1=find(fast.n>0.1); + bign0=bign1-1; + df=max((F(bign1)-F(bign0))./(F(bign0))); + + P.C_init= P.C_0; + S0 = Hill_v1(P,P.C_0); + arg = S0 + df*(S0 + 1/13); + P.A = ((arg*P.k_d)./(1-arg)).^(1/P.n)-P.C_0; + end + P.C_0 = 0; + P.tau_c = fast.V.dt/(1-fast.P.gam); % time constant + nnorm = V.n_max*fast.n/max(fast.n); % normalize inferred spike train + C = filter(1,[1 -fast.P.gam],P.A*nnorm)'+P.C_0; % calcium concentration + C1 = [Hill_v1(P,C); ones(1,V.T)]; % for brevity + ab = C1'\F'; % estimate scalse and offset + P.alpha = ab(1); % fluorescence scale + P.beta = ab(2); % fluorescence offset + P.zeta = (mad(F-ab'*C1,1)*1.4785)^2; + P.gamma = P.zeta/5; % signal dependent noise + P.k = V.spikegen.EFGinv(0.01, P, V); + + end + [smc.E smc.P smc.V] = smc_oopsi(F,V,P); + if V.save, save(V.name_dat,'smc','-append'); end +end + +%% provide outputs for later analysis + +if nargout == 1 + if V.fast_iter_max > 0 + varargout{1} = fast; + else + varargout{1} = smc; + end +elseif nargout == 2 + if V.fast_iter_max>0 && V.smc_iter_max>0 + varargout{1} = fast; + varargout{2} = smc; + else + if V.fast_iter_max>0 + varargout{1} = fast; + varargout{2} = V; + else + varargout{1} = smc; + varargout{2} = V; + end + end +elseif nargout == 3 + varargout{1} = fast; + varargout{2} = smc; + varargout{3} = V; +end + +%% plot results + +if V.plot + if isfield(V,'fig_dir'), ffig=V.fig_dir; end + V.name_fig = [ffig V.name]; % filename for figure + fig = figure(3); + clf, + V.T = length(F); + nrows = 1; + if V.fast_iter_max>0, nrows=nrows+1; end + if V.smc_iter_max>0, nrows=nrows+1; end + gray = [.75 .75 .75]; % define gray color + inter = 'tex'; % interpreter for axis labels + xlims = [1 V.T]; % xmin and xmax for current plot + fs = 12; % font size + ms = 5; % marker size for real spike + sw = 2; % spike width + lw = 1; % line width + xticks = 0:1/V.dt:V.T; % XTick positions + skip = round(length(xticks)/5); + xticks = xticks(1:skip:end); + tvec_o = xlims(1):xlims(2); % only plot stuff within xlims + if isfield(V,'true_n'), V.n=V.true_n; end + if isfield(V,'n'), spt=find(V.n); end + + % plot fluorescence data + i=1; h(i)=subplot(nrows,1,i); hold on + plot(tvec_o,z1(F(tvec_o)),'-k','LineWidth',lw); + ylab=ylabel([{'Fluorescence'}],'Interpreter',inter,'FontSize',fs); + set(ylab,'Rotation',0,'HorizontalAlignment','right','verticalalignment','middle') + set(gca,'YTick',[],'YTickLabel',[]) + set(gca,'XTick',xticks,'XTickLabel',[]) + axis([xlims 0 1.1]) + + % plot fast-oopsi output + if V.fast_iter_max>0 + i=i+1; h(i)=subplot(nrows,1,i); hold on, + n_fast=fast.n/max(fast.n); + spts=find(n_fast>1e-3); + stem(spts,n_fast(spts),'Marker','none','LineWidth',sw,'Color','k') + if isfield(V,'n'), + stem(spt,V.n(spt)/max(V.n(spt))+0.1,'Marker','v','MarkerSize',ms,'LineStyle','none','MarkerFaceColor',gray,'MarkerEdgeColor',gray) + end + axis([xlims 0 1.1]) + hold off, + ylab=ylabel([{'fast'}; {'filter'}],'Interpreter',inter,'FontSize',fs); + set(ylab,'Rotation',0,'HorizontalAlignment','right','verticalalignment','middle') + set(gca,'YTick',0:2,'YTickLabel',[]) + set(gca,'XTick',xticks,'XTickLabel',[]) + box off + end + + % plot smc-oopsi output + if V.smc_iter_max>0 + i=i+1; h(i)=subplot(nrows,1,i); hold on, + spts=find(smc.E.nbar>1e-3); + stem(spts,smc.E.nbar(spts),'Marker','none','LineWidth',sw,'Color','k') + if isfield(V,'n'), + stem(spt,V.n(spt)+0.1,'Marker','v','MarkerSize',ms,'LineStyle','none','MarkerFaceColor',gray,'MarkerEdgeColor',gray) + end + axis([xlims 0 1.1]) + hold off, + ylab=ylabel([{'smc'}; {'filter'}],'Interpreter',inter,'FontSize',fs); + set(ylab,'Rotation',0,'HorizontalAlignment','right','verticalalignment','middle') + set(gca,'YTick',0:2,'YTickLabel',[]) + set(gca,'XTick',xticks,'XTickLabel',[]) + box off + end + + % label last subplot + set(gca,'XTick',xticks,'XTickLabel',round(xticks*V.dt*100)/100) + xlabel('Time (sec)','FontSize',fs) + linkaxes(h,'x') + + % print fig + if V.save + wh=[7 3]; %width and height + set(gcf,'PaperSize',wh,'PaperPosition',[0 0 wh],'Color','w'); + print('-depsc',V.name_fig) + print('-dpdf',V.name_fig) + saveas(fig,V.name_fig) + end +end \ No newline at end of file diff --git a/smc_oopsi.m b/smc_oopsi.m new file mode 100755 index 0000000..16a8171 --- /dev/null +++ b/smc_oopsi.m @@ -0,0 +1,117 @@ +function [M P V] = smc_oopsi(F,V,P) +% this function runs the SMC-EM on a fluorescence time-series, and outputs the inferred +% distributions and parameter estimates +% +% Inputs +% F: fluorescence time series +% V: structure of stuff necessary to run smc-em code +% P: structure of initial parameter estimates +% +% Outputs +% M: structure containing mean, variance, and percentiles of inferred distributions +% P: structure containing the final parameter estimates +% V: structure Variables for algorithm to run + +if nargin < 2, V = struct; end +if ~isfield(V,'T'), V.T = length(F); end % # of observations +if ~isfield(V,'freq'), V.freq = 1; end % # time steps between observations +if ~isfield(V,'T_o'), V.T_o = V.T; end % # of observations +if ~isfield(V,'x'), V.x = ones(1,V.T); end % stimulus +if ~isfield(V,'scan'), V.scan = 0; end % epi or scan +if ~isfield(V,'name'), V.name ='oopsi'; end % name for output and figure +if ~isfield(V,'Nparticles'), V.Nparticles = 99; end % # particles +if ~isfield(V,'Nspikehist'), V.Nspikehist = 0; end % # of spike history terms +if ~isfield(V,'CaBaselineDrift'), V.CaBaselineDrift = 0; end %whether to include baseline drift in resting calcium concentration +if V.CaBaselineDrift && V.freq > 1 + warning('baseline drift is not supported for intermittent sampling, using fixed baseline instead'); + V.CaBaselineDrift = 0; +end +if ~isfield(V,'condsamp'), V.condsamp = 1; end % whether to use conditional sampler + +if ~isfield(V,'ignorelik'), V.ignorelik = 1; end % epi or scan +if ~isfield(V,'true_n'), % if true spikes are not available + V.use_true_n = 0; % don't use them for anything +else + V.use_true_n = 1; + V.true_n = repmat(V.true_n',V.Nparticles,1); +end +if ~isfield(V,'smc_iter_max'), % max # of iterations before convergence + reply = str2double(input('\nhow many EM iterations would you like to perform \nto estimate parameters (0 means use default parameters): ', 's')); + V.smc_iter_max = reply; +end +if ~isfield(V,'dt'), + fr = input('what was the frame rate for this movie (in Hz)? '); + V.dt = 1/fr; +end + +% set which parameters to estimate +if ~isfield(V,'est_c'), V.est_c = 1; end % tau_c, A, C_0 +if ~isfield(V,'est_t'), V.est_t = 1; end % tau_c (useful when data is poor) +if ~isfield(V,'est_n'), V.est_n = 1; end % b,k +if ~isfield(V,'est_h'), V.est_h = 0; end % w +if ~isfield(V,'est_F'), V.est_F = 1; end % alpha, beta +if ~isfield(V,'smc_plot'), V.smc_plot = 1; end % plot results with each iteration + +%% initialize model Parameters + +if nargin < 3, P = struct; end +if ~isfield(P,'tau_c'), P.tau_c = 1; end % calcium decay time constant (sec) +if ~isfield(P,'A'), P.A = 50; end % change ins [Ca++] after a spike (\mu M) +if ~isfield(P,'C_0'), P.C_0 = 30; end % baseline [Ca++] (\mu M) +if ~isfield(P,'C_init'),P.C_init= 20; end % initial [Ca++] (\mu M) +if ~isfield(P,'sigma_c'),P.sigma_c= 10; end % standard deviation of noise (\mu M) +if ~isfield(P,'sigma_cr'),P.sigma_cr= 10; end % standard deviation of noise (\mu M) +if ~isfield(P,'n'), P.n = 1; end % hill equation exponent +if ~isfield(P,'k_d'), P.k_d = 200; end % hill coefficient + +if ~isfield(V,'spikegen') + [sg,V] = get_spikegen_info('Bernoulli', P, V); + V.spikegen = sg; +end +if V.freq > 1 && ~strcmpi(V.spikegen.name, 'bernoulli') + warning('only binary spiking is supported with intermittent sampling, switching to bernoulli spike count distributions'); + [sg,V] = get_spikegen_info('Bernoulli', P, V); + V.spikegen = sg; +end + +if ~isfield(V, 'maxspikes'), V.maxspikes = 15; end +if ~isfield(P,'k'), % linear filter + k = str2double(input('approx. how many spikes underly this trace: ', 's')); + P.k = V.spikegen.EFGinv(k / V.T); + %P.k = log(-log(1-k/V.T)/V.dt); +end + +if ~isfield(P,'alpha'), P.alpha = mean(F); end % scale of F +if ~isfield(P,'beta'), P.beta = min(F); end % offset of F +if ~isfield(P,'zeta'), P.zeta = P.alpha/5; end % constant variance +if ~isfield(P,'gamma'), P.gamma = P.zeta/5; end % scaled variance +if V.Nspikehist==1 % if there are spike history terms + if ~isfield(P,'omega'), P.omega = -1; end % weight + if ~isfield(P,'tau_h'), P.tau_h = 0.02; end % time constant + if ~isfield(P,'sigma_h'), P.sigma_h = 0; end % stan dev of noise + if ~isfield(P,'g'), P.g = V.dt/P.tau_h; end % for brevity + if ~isfield(P,'sig2_h'), P.sig2_h = P.sigma_h^2*V.dt; end % for brevity +end +if ~isfield(P,'a'), P.a = V.dt/P.tau_c; end % for brevity +if ~isfield(P,'sig2_c'),P.sig2_c= P.sigma_c^2*V.dt; end % for brevity +if ~isfield(P,'sig2_cr'),P.sig2_cr= P.sigma_cr^2*V.dt; end % for brevity + +%% initialize other stuff +starttime = cputime; +P.lik = -inf; % we are trying to maximize the likelihood here +F = max(F,eps); % in case there are any zeros in the F time series +% P.k = [P.k; 1]; +S = smc_oopsi_forward(F,V,P); % forward step +M = smc_oopsi_backward(S,V,P); % backward step +if V.smc_iter_max>1, P.conv=false; else P.conv=true; end + +while P.conv==false; + P = smc_oopsi_m_step(V,S,M,P,F); % m step + S = smc_oopsi_forward(F,V,P); % forward step + M = smc_oopsi_backward(S,V,P); % backward step +end +fprintf('\n') + +V.smc_iter_tot = length(P.lik); +V.smc_time = cputime-starttime; +V = orderfields(V); \ No newline at end of file diff --git a/smc_oopsi_backward.m b/smc_oopsi_backward.m new file mode 100755 index 0000000..de9a39b --- /dev/null +++ b/smc_oopsi_backward.m @@ -0,0 +1,183 @@ +function M = smc_oopsi_backward(S,V,P) +% this function iterates backward one step computing P[H_t | H_{t+1},O_{0:T}] +% Input--- +% Sim: simulation metadata +% S: particle positions and weights +% P: parameters +% Z: a bunch of stuff initialized for speed +% t: current time step +% +% Output is a single structure Z with the following fields +% n1: vector of spikes or no spike for each particle at time t +% C0: calcium positions at t-1 +% C1: calcium positions at t (technically, this need not be output) +% C1mat:matrix from C1 +% C0mat:matrix from C0 +% w_b: backwards weights + +fprintf('\nbackward step: ') +Z.oney = ones(V.Nparticles,1); % initialize stuff for speed +Z.zeroy = zeros(V.Nparticles); +Z.C0 = S.C(:,V.T); +Z.C0mat = Z.C0(:,Z.oney)'; +Z.constant_spikedistrib = 0; +direct_ind_mat = (1:V.Nparticles)' * ones(1,V.T) + S.n * V.Nparticles; +Z.constant_spikedistrib = 0; +if V.Nspikehist==0 && all(S.p(:) == S.p(1)) + Z.constant_spikedistrib = 1; + ln_Pn_allposs = V.spikegen.logdensity(S.p(1), P, V); %P[n] for all possible values of n_t. column vector. + Z.all_ln_Pn = ln_Pn_allposs(S.n + 1); +end +if V.est_c==false && 0 %temporary hack (-dg), forces full sufficient statistics calc for now. + % if not maximizing the calcium parameters, then the backward step is simple + if V.use_true_n % when spike train is provided, backwards is not necessary + S.w_b=S.w_f; + else + for t=V.T-V.freq-1:-1:V.freq+1 % actually recurse backwards for each time step + Z = step_backward(V,S,P,Z,t); + S.w_b(:,t-1) = Z.w_b; % update forward-backward weights + end + end +else % if maximizing calcium parameters, + % need to compute some sufficient statistics + M.Q = zeros(3); % the quadratic term for the calcium par + M.L = zeros(3,1); % the linear term for the calcium par + M.J = 0; % remaining terms for calcium par + M.K = 0; + for t=V.T-V.freq-1:-1:V.freq+1 + % if V.use_true_n % force true spikes hack + % Z.C0 = S.C(:,t-1); + % Z.C0mat = Z.C0; + % Z.C1 = S.C(:,t); + % Z.C1mat = Z.C1; + % Z.PHH = 1; + % Z.w_b = 1; + % Z.n1 = S.n(t); + % else + + %%THE FOLLOWING WAS PREVIOUSLY THE STEP_BACKWARD SUBFUNCTION + % compute ln P[n_t^i | h_t^i] + Z.n1 = S.n(:,t); % for prettiness sake + if Z.constant_spikedistrib + Z.ln_Pn = Z.all_ln_Pn(:,t); + else + ln_Pn_allposs = V.spikegen.logdensity(S.p(:,t)', P, V)'; %P[n_t | h_t] for all possible values of n_t + Z.ln_Pn = ln_Pn_allposs(direct_ind_mat(:,t)); %P[n_t | h_t] for the n_t we sampled + end + + if V.CaBaselineDrift %compute ln P[Cr_t^i | Cr_{t-1}^j] + Z.Cr0 = S.Cr(:,t-1); + Z.Cr1 = S.Cr(:,t); + Z.Cr1mat = Z.Cr1(:,Z.oney); + Z.Cr0mat = Z.Cr0(:,Z.oney); + Z.ln_PCr_Crn = -0.5 * (Z.Cr1mat - Z.Cr0mat').^2 / P.sig2_cr; + else + Z.ln_PCr_Crn = Z.zeroy; + end + + % compute ln P[C_t^i | C_{t-1}^j, n_t^i] + Z.C0 = S.C(:,t-1); % for prettiness sake + Z.C1 = S.C(:,t); + Z.C1mat = Z.C1(:,Z.oney); % recall from previous time step + Z.C0mat = Z.C0(:,Z.oney); % faster than repmat + if V.CaBaselineDrift + Z.mu = (1-P.a) * S.C(:,t-1) + P.A * Z.n1 + P.a * S.Cr(:,t);% mean + else + Z.mu = (1-P.a) * S.C(:,t-1) + P.A * Z.n1 + P.a * P.C_0;% mean + end + Z.mumat = Z.mu(:,Z.oney)'; % faster than repmat + Z.ln_PC_Cn = -(0.5 / P.sig2_c) * (Z.C1mat - Z.mumat).^2; % P[C_t^i | C_{t-1}^j, n_t^i] + + % compute ln P[h_t^i | h_{t-1}^j, n_{t-1}^i] + Z.ln_Ph_hn = Z.zeroy; % reset transition prob for h terms + for m=1:V.Nspikehist % for each h term + h1 = S.h(:,t,m); % faster than repmat + h1 = h1(:,Z.oney); + h0 = P.g(m)*S.h(:,t-1,m)+S.n(:,t-1); + h0 = h0(:,Z.oney)'; + Z.ln_Ph_hn = Z.ln_Ph_hn - 0.5*(h0 - h1).^2/P.sig2_h(m); + end + + % compute P[H_t^i | H_{t-1}^j] + Z.sum_lns = Z.ln_Pn(:,Z.oney) + Z.ln_PC_Cn + Z.ln_PCr_Crn + Z.ln_Ph_hn; % in order to ensure this product doesn't have numerical errors + Z.mx = max(Z.sum_lns,[],1); % find max in each of row + Z.mx = Z.mx(Z.oney,:); % make a matrix of maxes + Z.T0 = exp(Z.sum_lns-Z.mx); % exponentiate subtracting maxes (so that in each row, the max entry is exp(0)=1 + %Tn = sum(T0,1); % then normalize + %T = T0 .* repmat(1./Tn, V.Nparticles, 1); % such that each column sums to 1 + Z.Tninv = 1 ./ sum(Z.T0,1); % then normalize + Z.T = Z.T0.* Z.Tninv(Z.oney,:); + + % compute P[H_t^i, H_{t-1}^j | O] + Z.PHHn = (Z.T*S.w_f(:,t-1)); % denominator + Z.PHHn(Z.PHHn==0) = eps; + Z.PHHn_inv = 1 ./ Z.PHHn; + Z.PHH = Z.T .* ((Z.PHHn_inv .* S.w_b(:,t)) * S.w_f(:,t-1)'); % normalize such that sum(PHH)=1 + Z.sumPHH = sum(Z.PHH(:)); + if Z.sumPHH==0 + Z.PHH = ones(V.Nparticles)/(V.Nparticles); + else + Z.PHH = Z.PHH / Z.sumPHH; + end + Z.w_b = sum(Z.PHH,1); % marginalize to get P[H_t^i | O] + + if any(isnan(Z.w_b)) + return + end + + if mod(t,100)==0 && t>=9900 + fprintf('\b\b\b\b\b%d',t) + elseif mod(t,100)==0 && t>=900 + fprintf('\b\b\b\b%d',t) + elseif mod(t,100)==0 + fprintf('\b\b\b%d',t) + end + %%THAT WAS PREVIOUSLY THE STEP_BACKWARD SUBFUNCTION + + % end + S.w_b(:,t-1) = Z.w_b; + if V.smc_iter_max > 1 + + % below is code to quickly get sufficient statistics + C0dt = Z.C0*V.dt; + bmat = Z.C1mat-Z.C0mat'; + bPHH = Z.PHH.*bmat; + + M.Q(1,1)= M.Q(1,1) + sum(Z.PHH*(C0dt.^2)); % Q-term in QP + M.Q(1,2)= M.Q(1,2) - Z.n1'*Z.PHH*C0dt; + M.Q(1,3)= M.Q(1,3) - sum(sum(Z.PHH .* Z.C0mat')) * V.dt^2; + M.Q(2,2)= M.Q(2,2) + sum(Z.PHH'*(Z.n1.^2)); + M.Q(2,3)= M.Q(2,3) + sum(Z.n1' * Z.PHH) * V.dt; + %M.Q(2,3)= M.Q(2,3) + sum( sum(Z.PHH(:).*repmat(Z.n1,V.Nparticles,1)) *V.dt); + M.Q(3,3)= M.Q(3,3) + sum(Z.PHH(:))*V.dt^2; + + M.L(1) = M.L(1) + sum(bPHH*C0dt); % L-term in QP + M.L(2) = M.L(2) - sum(bPHH'*Z.n1); + M.L(3) = M.L(3) - V.dt*sum(bPHH(:)); + + M.J = M.J + sum(Z.PHH(:)); % J-term in QP /sum J^(i,j)_{t,t-1}/ + + M.K = M.K + sum(Z.PHH(:).*bmat(:).^2); % K-term in QP /sum J^(i,j)_{t,t-1} (d^(i,j)_t)^2/ + end + end + if V.smc_iter_max > 1 + M.Q(2,1) = M.Q(1,2); % symmetrize Q + M.Q(3,1) = M.Q(1,3); + M.Q(3,2) = M.Q(2,3); + end +end +fprintf('\n') + +% copy particle swarm for later +M.w = S.w_b; +M.n = S.n; +M.C = S.C; +if V.CaBaselineDrift + M.Cr = S.Cr; + M.Crbar = sum(S.w_b.*S.Cr,1); +end +if isfield(S,'h'), M.h=S.h; end +M.nbar = sum(S.w_b.*S.n,1); +M.Cbar = sum(S.w_b.*S.C,1); + +end \ No newline at end of file diff --git a/smc_oopsi_forward.m b/smc_oopsi_forward.m new file mode 100755 index 0000000..8ae942c --- /dev/null +++ b/smc_oopsi_forward.m @@ -0,0 +1,510 @@ +function S = smc_oopsi_forward(F,V,P) +% the function does the backwards sampling particle filter +% notes: this function assumes spike histories are included. to turn them +% off, make sure that V.Nspikehist=0 (M is the # of spike history terms). +% +% The backward sampler has standard variance, as approximated typically. +% Each particle has the SAME backwards sampler, initialized at E[h_t] +% this function only does spike history stuff if necessary +% +% the model is F_t = f(C) = alpha C^n/(C^n + k_d) + beta + e_t, +% where e_t ~ N[0, gamma*f(C)+zeta] +% +% Inputs--- +% F: Fluorescence +% V: Variables for algorithm to run +% P: initial Parameter estimates +% +% Outputs--- +% S: simulation states + +%% allocate memory and initialize stuff + +fprintf('\nT = %g steps',V.T) +fprintf('\nforward step: ') +P.kx = P.k'*V.x; +A.nsmat = ones(V.Nparticles,1) * (0:V.maxspikes); +A.repmat_msp1 = ones(V.maxspikes + 1,1); +IsBernoulli = strcmpi(V.spikegen.name, 'bernoulli'); + +% extize particle info +S.p = zeros(V.Nparticles,V.T); %intialize the parameter describing the spiking distribution. in the case of Bernoulli and Poisson, this will also be the rate +S.n = zeros(V.Nparticles,V.T); % extize spike counts +S.next_n = zeros(V.Nparticles,1); +S.C = P.C_init*ones(V.Nparticles,V.T); % extize calcium +if V.CaBaselineDrift + S.Cr = P.C_init*ones(V.Nparticles,V.T); % extize resting calcium + A.epsilon_cr = randn(V.Nparticles,V.T); % generate noise on Cr +end +S.w_f = 1/V.Nparticles*ones(V.Nparticles,V.T); % extize forward weights +S.w_b = 1/V.Nparticles*ones(V.Nparticles,V.T); % extize forward weights +S.Neff = 1/V.Nparticles*ones(1,V.T_o); % extize N_{eff} + +% preprocess stuff for stratified resampling +ints = linspace(0,1,V.Nparticles+1); % generate intervals +diffs = ints(2)-ints(1); % generate interval size +A.U_resamp = repmat(ints(1:end-1),V.T_o,1)+diffs*rand(V.T_o,V.Nparticles); % resampling matrix + +% sample random variables +A.U_sampl = rand(V.Nparticles,V.T); % random samples +A.epsilon_c = randn(V.Nparticles,V.T); % generate noise on C + +A.constant_spikedistrib = 0; +if V.Nspikehist>0 % if spike histories + S.h = zeros(V.Nparticles,V.T,V.Nspikehist); % extize spike history terms + A.epsilon_h = zeros(V.Nparticles, V.T, V.Nspikehist); % generate noise on h + for m=1:V.Nspikehist % add noise to each h + A.epsilon_h(:,:,m) = sqrt(P.sig2_h(m))*randn(V.Nparticles,V.T); + end +else % if not, comput P[n_t] for all t + S.p = repmat(V.spikegen.nonlinearity(P.kx, P, V)',1,V.Nparticles)'; + if V.freq == 1 && all(S.p(:) == S.p(1)) + A.constant_spikedistrib = 1; + V.common_ln_n = V.spikegen.logdensity(S.p(:,1)', P, V)'; % compute log(P[n spikes], n = 0,1,2 ... each row is a particle, each column is a spike count + end +end + +% extize stuff needed for conditional sampling +A.n_sampl = rand(V.Nparticles,V.T); % generate random number to use for sampling n_t +A.C_sampl = rand(V.Nparticles,V.T); % generate random number to use for sampling C_t +A.oney = ones(V.Nparticles,1); % vector of ones to call for various purposes to speed things up +A.zeroy = zeros(V.Nparticles,1); % vector of zeros + +% extize stuff needed for REAL backwards sampling +O.p_o = zeros(2^(V.freq-1),V.freq); % extize backwards prob +mu_o = zeros(2^(V.freq-1),V.freq); % extize backwards mean +sig2_o = zeros(1,V.freq); % extize backwards variance + +% extize stuff needed for APPROX backwards sampling +O.p = zeros(V.freq,V.freq); % extize backwards prob +O.mu = zeros(V.freq,V.freq); % extize backwards mean +O.sig2 = zeros(V.freq,V.freq); % extize backwards var + +% initialize backwards distributions +s = V.freq; % initialize time of next observation +O.p_o(1,s) = 1; % initialize P[F_s | C_s] + +[mu_o(1,s) sig2_o(s)] = init_lik(P,F(s)); + +O.p(1,s) = 1; % initialize P[F_s | C_s] +O.mu(1,s) = mu_o(1,s); % initialize mean of P[O_s | C_s] +O.sig2(1,s) = sig2_o(s); % initialize var of P[O_s | C_s] + +if V.freq>1 % if intermitent sampling + for tt=s:-1:2 % generate spike binary matrix + A.spikemat(:,tt-1) = repmat([repmat(0,1,2^(s-tt)) repmat(1,1,2^(s-tt))],1,2^(tt-2))'; + end + nspikes = sum(A.spikemat')'; % count number of spikes at each time step + + for n=0:V.freq-1 + A.ninds{n+1}= find(nspikes==n); % get index for each number of spikes + A.lenn(n+1) = length(A.ninds{n+1}); % find how many spikes + end +end + +O = update_moments(V,F,P,S,O,A,s); % recurse back to get P[O_s | C_s] before the first observation + +%% do the particle filter + +for t=V.freq+1:V.T-V.freq + if IsBernoulli && V.freq > 1 && ~V.CaBaselineDrift && V.condsamp + S = cond_sampler_intermittent_bernoulli_nodrift(V,F,P,S,O,A,t,s); + elseif V.condsamp==0 || (F(s+V.freq)-P.beta)/P.alpha>0.98 || (F(s+V.freq)-P.beta)/P.alpha<0 % prior sampler with drift when gaussian approximation is not good + S = prior_sampler(V,F,P,S,A,t); % use prior sampler when gaussian approximation is not good + else % otherwise use conditional sampler + S = cond_sampler_nonintermittent(V,F,P,S,O,A,t,s); + end + + S.C(:,t)=S.next_C; % done to speed things up for older + if V.CaBaselineDrift + S.Cr(:,t) = S.next_Cr; + end + S.n(:,t)=S.next_n; % matlab, having issues with from-function + if(isfield(S,'next_w_f')) % update calls to large structures + S.w_f(:,t)=S.next_w_f; + else + S.w_f(:,t)=S.w_f(:,t-1); + end + + if V.Nspikehist>0 % update S.h & S.p + for m=1:V.Nspikehist, S.h(:,t,m)=S.h_new(:,1,m); end + S.p(:,t)=S.p_new; + end + + % at observations + if mod(t,V.freq)==0 + % STRAT_RESAMPLE -- THIS WAS CAUSING TROUBLE IN OLDER MATLAB WITH + % MANAGING LARGE ARRAYS IN S-STRUCTURE WITHIN THE FUNCTION CALL + % S = strat_resample(V,S,t,A.U_resamp); % stratified resample + + Nresamp=t/V.freq; % increase sample counter + S.Neff(Nresamp) = 1/sum(S.w_f(:,t).^2); % store N_{eff} + + % if weights are degenerate or we are doing prior sampling then resample + if S.Neff(Nresamp) < V.Nparticles/2 || V.condsamp==0 + [foo,ind] = histc(A.U_resamp(Nresamp,:),[0 cumsum(S.w_f(:,t))']); + [ri,ri] = sort(rand(V.Nparticles,1)); % these 3 lines stratified resample + ind = ind(ri); + + S.p(:,t-V.freq+1:t) = S.p(ind,t-V.freq+1:t); % resample probabilities (necessary?) (yes, i think we need these for the backward step -- dg) + S.n(:,t-V.freq+1:t) = S.n(ind,t-V.freq+1:t); % resample spikes + S.C(:,t-V.freq+1:t) = S.C(ind,t-V.freq+1:t); % resample calcium + if V.CaBaselineDrift + S.Cr(:,t-V.freq+1:t) = S.Cr(ind,t-V.freq+1:t); % resample resting calcium + end + S.w_f(:,t-V.freq+1:t) = 1/V.Nparticles*ones(V.Nparticles,V.freq); % reset weights + if V.Nspikehist>0 % if spike history terms + S.h(:,t-V.freq+1:t,:) = S.h(ind,t-V.freq+1:t,:);% resample all h's + end + end %function + + O = update_moments(V,F,P,S,O,A,t); % estimate P[O_s | C_tt] for all t'0 && imag(mu1)==0 + mu_nm1power = mu1 .^ (P.n - 1); + gprime = P.alpha * P.n * mu_nm1power * P.k_d ./ ((mu1 .* mu_nm1power + P.k_d).^2); + + SCa_at_mu1 = (F - P.beta) / P.alpha; + sig1 = (P.gamma * SCa_at_mu1 + P.zeta) ./ (gprime^2); +else + mu1=0; + sig1=0; +end + +end %init_lik + +%% update moments + +function O = update_moments(V,F,P,S,O,A,t) +%%%% maybe make a better proposal for epi + +s = V.freq; % find next observation time + +[mu1 sig1] = init_lik(P,F(t+s)); + +O.p(1,s) = 1; % initialize P[F_s | C_s] +O.mu(1,s) = mu1; % initialize mean of P[O_s | C_s] +O.sig2(1,s) = sig1; % initialize var of P[O_s | C_s] + +if s > 1 + if V.CaBaselineDrift + error('intermittent sampling + baseline drift is not yet implemented'); + end + if ~strcmpi(V.spikegen.name, 'bernoulli') + error('intermittent sampling is only implemented for bernoulli spiking'); + end + + mu_o(1,s) = mu1; % initialize mean of P[O_s | C_s] + sig2_o(s) = sig1; % initialize var of P[O_s | C_s] + + + if V.Nspikehist>0 + hhat = zeros(V.freq,V.Nspikehist); % extize hhat + phat = zeros(1,V.freq+1); % extize phat + + hs = S.h(:,t,:); % this is required for matlab to handle a m-by-n-by-p matrix + h(:,1:V.Nspikehist)= hs(:,1,1:V.Nspikehist); % this too + hhat(1,:) = sum(repmat(S.w_f(:,t),1,V.Nspikehist).*h,1); % initialize hhat + phat(1) = sum(S.w_f(:,t).*S.p(:,t),1); % initialize phat + end + + if V.Nspikehist>0 + for tt=1:s + % update hhat + for m=1:V.Nspikehist % for each spike history term + hhat(tt+1,m)=(1-P.g(m))*hhat(tt,m)+phat(tt); + end + y_t = P.kx(tt+t)+P.omega'*hhat(tt+1,:)'; % input to neuron + phat(tt+1) = 1-exp(-exp(y_t)*V.dt); % update phat + end + else + phat = 1-exp(-exp(P.kx(t+1:t+s)')*V.dt); % update phat + end + + for tt=s:-1:2 + + O.p_o(1:2^(s-tt+1),tt-1) = repmat(O.p_o(1:2^(s-tt),tt),2,1).*[(1-phat(tt))*ones(1,2^(s-tt)) phat(tt)*ones(1,2^(s-tt))]'; + mu_o(1:2^(s-tt+1),tt-1) = (1-P.a)^(-1)*(repmat(mu_o(1:2^(s-tt),tt),2,1)-P.A*A.spikemat(1:2^(s-tt+1),tt-1)-P.a*P.C_0); %mean of P[O_s | C_k] + sig2_o(tt-1) = (1-P.a)^(-2)*(P.sig2_c+sig2_o(tt)); % var of P[O_s | C_k] + + for n=0:s-tt+1 + nind=A.ninds{n+1}; + O.p(n+1,tt-1) = sum(O.p_o(nind,tt-1)); + ps = (O.p_o(nind,tt-1)/O.p(n+1,tt-1))'; + O.mu(n+1,tt-1) = ps*mu_o(nind,tt-1); + O.sig2(n+1,tt-1)= sig2_o(tt-1) + ps*(mu_o(nind,tt-1)-repmat(O.mu(n+1,tt-1)',A.lenn(n+1),1)).^2; + end + end + + if s==2 + O.p = O.p_o; + O.mu = mu_o; + O.sig2 = repmat(sig2_o,2,1); + O.sig2(2,2) = 0; + end + + while any(isnan(O.mu(:))) % in case ps=0/0, which yields a NaN, approximate mu and sig + O.mu(1,:) = mu_o(1,:); + O.sig2(1,:) = sig2_o(1,:); + ind = find(isnan(O.mu)); + O.mu(ind) = O.mu(ind-1)-P.A; + O.sig2(ind) = O.sig2(ind-1); + end +end +O.p=O.p+eps; % such that there are no actual zeros +end %function UpdateMoments + +%% particle filtering using the prior sampler +function S = prior_sampler(V,F,P,S,A,t) +S = calc_next_spikedist_params(V,P,S,A,t); +if V.use_true_n + S.next_n = V.true_n(:,t); +else + den = V.spikegen.density(S.p_new', P, V); + for pind = 1:V.Nparticles + cden = cumsum([0 den(:,pind)']); + cden(end) = 1 + eps; + S.next_n(pind,1) = find(A.U_sampl(pind,t) > cden,1,'last') - 1; + end +end +if V.CaBaselineDrift + S.next_Cr = S.Cr(:,t-1) + sqrt(P.sig2_cr) * A.epsilon_cr(:,t); + S.next_C = (1-P.a) * S.C(:,t-1) + P.A * S.next_n + P.a * S.Cr(:,t) + sqrt(P.sig2_c) * A.epsilon_c(:,t);% sample C +else + S.next_C = (1-P.a) * S.C(:,t-1) + P.A *S.next_n + P.a*P.C_0 + sqrt(P.sig2_c) * A.epsilon_c(:,t);% sample C +end +% get weights at every observation %THIS NEEDS FIX FOR EPI DATA +if mod(t,V.freq)==0 + S_mu = Hill_v1(P,S.next_C); + F_mu = P.alpha*S_mu+P.beta; % compute E[F_t] + F_var = P.gamma*S_mu+P.zeta; % compute V[F_t] + %%%% this must also change for epi + ln_w = -0.5*(F(t)-F_mu).^2./F_var - log(F_var)/2;% compute log of weights + ln_w = ln_w-max(ln_w); % subtract the max to avoid rounding errors + w = exp(ln_w); % exponentiate to get actual weights + % error('forgot to include the previous weight in this code!!!!') + % break + S.next_w_f = w/sum(w); % normalize to define a legitimate distribution +end + +end + +function S = calc_next_spikedist_params(V,P,S,A,t) % if spike histories, sample h and update p +if V.Nspikehist>0 + S.h_new=zeros(size(S.n,1),1,V.Nspikehist); + for m=1:V.Nspikehist + S.h_new(:,1,m)=(1-P.g(m))*S.h(:,t-1,m)+S.n(:,t-1)+A.epsilon_h(:,t,m); + end + % update rate and sample spikes + hs = S.h_new; % this is required for matlab to handle a m-by-n-by-p matrix + h(:,1:V.Nspikehist) = hs(:,1,1:V.Nspikehist); % this too + y_t = P.kx(t)+P.omega'*h'; % input to neuron + S.p_new = V.spikegen.nonlinearity(y_t, P, V); + %S.p_new = 1-exp(-exp(y_t)*V.dt);% update rate for those particles with y_t<0 + S.p_new = S.p_new(:); +else + S.p_new = S.p(:,t); +end +end + +function S = cond_sampler_intermittent_bernoulli_nodrift(V,F,P,S,O,A,t,s) +S = calc_next_spikedist_params(V,P,S,A,t); +k = V.freq-(t-s)+1; +A.repmat_ktimes_mat = ones(k,1); +m2 = O.mu(1:k,t-s); % mean of P[O_s | C_k] for n_k=1 and n_k=0 +v2 = O.sig2(1:k,t-s); % var of P[O_s | C_k] for n_k=1 and n_k=0 +prev_C = S.C(:,t-1); + +% compute P[n_k | h_k] +ln_n = [log(S.p_new) log(1-S.p_new)]; % compute [log(spike) log(no spike)] + +m0 = (1-P.a) * prev_C + P.a * P.C_0; % mean of P[C_k | C_{t-1}, n_k=0] +m1 = (1-P.a) * prev_C + P.a * P.C_0 + P.A ; % mean of P[C_k | C_{t-1}, n_k=1] + +v = P.sig2_c + v2; +v = v(:,A.oney)'; % var of G_n(n_k | O_s) for n_k=1 and n_k=0 + +%is the -0.5*log(2*pi.*v) term necessary ??? +ln_G0 = -0.5*log(2*pi * v) - 0.5 * (m0(:,A.repmat_ktimes_mat) - m2(:,A.oney)').^2 ./ v; +ln_G1 = -0.5*log(2*pi * v) - 0.5 * (m1(:,A.repmat_ktimes_mat) - m2(:,A.oney)').^2 ./ v; + +mx = max(max(ln_G0,[],2),max(ln_G1,[],2));% get max of these +mx = mx(:,A.repmat_ktimes_mat); + +%why calculate exp when we're soon going to multiply by something and then take logs again ??? +G1 = exp(ln_G1-mx); % norm dist'n for n=1; +M1 = G1*O.p(1:k,t-s); % times prob of n=1 + +G0 = exp(ln_G0-mx); % norm dist'n for n=0; +M0 = G0*O.p(1:k,t-s); % times prob n=0 + +ln_G = [log(M1) log(M0)]; % ok, now we actually have the gaussians + +% compute q(n_k | h_k, O_s) +ln_q_n = ln_n + ln_G; % log of sampling dist +mx = max(ln_q_n,[],2); % find max of each row +mx2 = [mx mx]; % matricize +q_n = exp(ln_q_n-mx2); % subtract max to ensure that for each row, there is at least one positive probability, and exponentiate +sq_n = sum(q_n,2); +q_n = q_n ./ [sq_n sq_n]; % normalize to make this a true sampling distribution (ie, sums to 1) + +% sample n +S.next_n= A.n_sampl(:,t) cden_mat,2) - 1; +% ---- STEP 2: Sample Cr if baseline is allowed to drift ---- +if V.CaBaselineDrift %sample Cr + mFCrn = (mCF - prev_C * (1 - P.a) - P.A * S.next_n) / P.a; %mean of P[F | Cr_t, n_t = S.next_n] + v_cr = (1 ./ vFCr + 1 / P.sig2_cr).^(-1); % variance of distribution for sampling Cr + m_cr = v_cr .* (prev_Cr / P.sig2_cr + mFCrn ./ vFCr); %mean of distribution for sampling Cr + S.next_Cr = m_cr + sqrt(v_cr) * A.epsilon_cr(:,t); + C_baseline = S.next_Cr; +else + C_baseline = P.C_0; +end +% ---- STEP 3: Sample C ---- +v_c = 1/(1/vCF + 1/P.sig2_c); %variance of dist'n for sampling C +m_givenspiking = (1-P.a) * S.C(:,t-1) + P.A * S.next_n + P.a * C_baseline; %expected calcium given spiking and previous calcium value +m_c = v_c * (mCF / vCF + m_givenspiking / P.sig2_c);%mean of dist'n for sampling C +S.next_C = m_c + sqrt(v_c) * A.epsilon_c(:,t); +% ---- STEP 4: update weights ---- +S_mu = Hill_v1(P,S.next_C); +F_mu = P.alpha*S_mu+P.beta; % compute E[F_t | C_t] +F_var = P.gamma*S_mu+P.zeta; % compute V[F_t | C_t] +log_PF_H = -0.5*(F(t)-F_mu).^2./F_var - log(F_var)/2; % compute log of P[F | H] !!!MUST CHANGE FOR EPI!!! + +direct_ind = (1:V.Nparticles)' + S.next_n * V.Nparticles; %index to access ln_n(pind, S.next_n(pind + 1) for pind = 1:V.nparticles +log_n = ln_n(direct_ind); +log_C_Cn = -0.5*(S.next_C-((1-P.a)*S.C(:,t-1)+P.A*S.next_n+P.a* C_baseline)).^2/P.sig2_c;%log P[C_k | C_{t-1}, n_k] + +log_q_n = log(q_n(direct_ind)); % compute what was the log prob of sampling the number of spikes we sampled, for each particle. +log_q_C = -0.5*(S.next_C - m_c).^2./v_c;% log prob of sampling the C_k that was sampled +if V.CaBaselineDrift + log_Cr_Cr = -0.5*(S.next_Cr - S.Cr(:,t-1)).^2 / P.sig2_cr; + log_q_Cr = -0.5*(S.next_Cr - m_cr).^2./v_cr ;% log prob of sampling the Cr_k that was sampled +else + log_Cr_Cr = 0; + log_q_Cr = 0; +end +log_quotient = log_PF_H + log_n + log_C_Cn + log_Cr_Cr - log_q_n - log_q_C - log_q_Cr; %note that terms from baseline drift cancel +sum_logs = log_quotient+log(S.w_f(:,t-1)); % update log(weights) +w = exp(sum_logs-max(sum_logs)); % exponentiate log(weights) +S.next_w_f = w./sum(w); % normalize such that they sum to unity +if any(isnan(w)) %, Fs=1024; ts=0:1/Fs:1; sound(sin(2*pi*ts*200)), + warning('smc:weights','some weights are NaN') + keyboard; +end +end \ No newline at end of file diff --git a/smc_oopsi_m_step.m b/smc_oopsi_m_step.m new file mode 100755 index 0000000..dc45d84 --- /dev/null +++ b/smc_oopsi_m_step.m @@ -0,0 +1,266 @@ +function P = smc_oopsi_m_step(V,S,M,P,F) +% this function finds the mle of the parameters +% +% Input--- +% V: simulation parameters +% R: real data +% S: simulation results +% M: moments and sufficient stats +% P: old parameter estimates +% +% Output is 'Enew', a structure with a field for each parameter, plus some +% additional fields for various likelihoods + +Eold= P; % store most recent parameter structure +P = update_params(V,S,M,P,F); % update parameters +i = length(P.lik); +fprintf('\n\nIteration #%g, lik=%g, dlik=%g\n',i,P.lik(end),P.lik(end)-Eold.lik(end)) + +% when estimating calcium parameters, display param estimates and lik +if V.est_c==1 + dtheta = norm([P.tau_c; P.A; P.C_0]-... + [Eold.tau_c; Eold.A; Eold.C_0])/norm([Eold.tau_c; Eold.A; Eold.C_0; P.sigma_c]); + fprintf('\ndtheta = %.2f',dtheta); + fprintf('\ntau = %.2f',P.tau_c) + fprintf('\nA = %.2f',P.A) + fprintf('\nC_0 = %.2f',P.C_0) + fprintf('\nsig = %.2f',P.sigma_c) + fprintf('\nalpha = %.2f',P.alpha) + fprintf('\nbeta = %.2f',P.beta) + fprintf('\ngamma = %.2g\n',P.gamma) +end +if V.est_n == true + fprintf('\nk = %.2f\n',P.k) +end + +% plot lik and inferrence +if V.smc_plot + figure(100) + subplot(4,1,1), hold off, plot(P.lik,'o'), axis('tight') + subplot(4,1,2), plot(F,'k'), hold on, + plot(P.alpha*Hill_v1(P,sum(S.w_b.*S.C,1))+P.beta,'b'), hold off, axis('tight') + + % plot spike train estimate + subplot(4,1,3), cla, hold on, + if isfield(V,'n'), + stem(V.n,'Marker','.','MarkerSize',20,'LineWidth',2,... + 'Color',[.75 .75 .75],'MarkerFaceColor','k','MarkerEdgeColor','k'); + axis('tight'), + end + stem(M.nbar,'Marker','none','LineWidth',2,'Color',[0 .5 0]) + ylabel('current n') + axis([0 V.T 0 1]), + + % plot "best" spike train estimate + subplot(4,1,4), cla,hold on, + if isfield(V,'n'), stem(V.n,'Marker','.',... + 'MarkerSize',20,'LineWidth',2,'Color',[.75 .75 .75]); end + stem(M.nbar,'Marker','none','LineWidth',2,'Color',[0 .5 0]); + ylabel('best n'); + axis([0 V.T 0 1]); + drawnow + %sound(3*sin(linspace(0,90*pi,2000))) % play sound to indicate iteration is over +end + +if i>=V.smc_iter_max + P.conv=true; +end + + function Enew = update_params(V,S,M,E,F) + Enew = E; % initialize parameters + lik = []; % initialize likelihood + + optionsQP = optimset('Display','off'); + optionsGLM = optimset('Display','off','GradObj','off','TolFun',1e-6,'LargeScale','off'); + + %% MLE for spiking parameters + if V.est_n == true + % MLE for spike rate parameters: baseline (b), linear filter (k), and spike history weights (omega) + fprintf('\nestimating spike rate params\n') + RateParams=E.k; % vector of parameters to estimate (changes depending on user input of which parameters to estimate) + sp = S.n==1; % find (particles,time step) pairs that spike + nosp = S.n==0; % don't spike + % x = repmat(V.x,1,V.Nparticles); % generate matrix for gradinent + zeroy = zeros(V.Nparticles,V.T); % make matrix of zeros for evaluating lik + + efunc = @(z) -V.spikegen.LLdata_for_spikeparams(z, S, P, V); + if V.Nspikehist == 0 && ismember(lower(V.spikegen.name), {'bernoulli','poisson'}) + Enew.k = mean(sum(S.n.*S.w_b)); + Enew.lik_r = -efunc(Enew.k); + else + + if V.est_h == true && V.Nspikehist > 0 + z0 = [E.k E.omega]; + else + z0 = E.k; + end + [best_z, neglik] = fminunc(efunc, z0, optionsGLM); + Enew.k = best_z(1:end-V.Nspikehist); + if V.est_h == true && V.Nspikehist > 0 + Enew.omega = best_z(end-V.Nspikehist + 1:end); + end + Enew.lik_r = -neglik; + end + +% old version has been commented out. was bernoulli-only and may have contained bugs --dg +% if V.est_h == true %whether or not to estimate the filter denoted by "w" in Vogelstein et. al 2009 +% +% if V.Nspikehist>0 % if spike history terms are present +% RateParams=[RateParams; E.omega]; % also estimate omega +% % for i=1:V.Nspikehist % and modify stimulus matrix for gradient +% % x(V.StimDim+i,:)=reshape(S.h(:,:,i),1,V.Nparticles*V.T); +% % end +% end +% %[bko neglik] = fminunc(@f_bko,RateParams,optionsGLM);% find MLE +% Z=ones(size(RateParams)); +% [bko neglik]=fmincon(@f_bko,RateParams,[],[],[],[],-10*Z,10*Z,[],optionsGLM);%fix for h-problem +% Enew.k = bko(1:end-V.Nspikehist); % set new parameter estimes +% if V.Nspikehist>0, Enew.omega = bko(end-V.Nspikehist+1:end); end % for omega too +% else +% for i=1:V.Nspikehist % and modify stimulus matrix for gradient +% x(V.StimDim+i,:)=reshape(S.h(:,:,i),1,V.Nparticles*V.T); +% end +% [bk neglik] = fminunc(@f_bk,RateParams,optionsGLM); % find MLE +% Enew.k = bk(1:end); % set new parameter estimes +% end +% Enew.lik_r = -neglik; + + lik = [lik Enew.lik_r]; + end + + function [lik dlik]= f_bko(RateParams) % get lik and grad + + xk = RateParams(1:end-V.Nspikehist)'*V.x; % filtered stimulus + hs = zeroy; % incorporate spike history terms + for l=1:V.Nspikehist, hs = hs+RateParams(end-V.Nspikehist+l)*S.h(:,:,l); end + s = repmat(xk,V.Nparticles,1) + hs; + + f_kdt = exp(s)*V.dt; % shorthand + ef = exp(f_kdt); % shorthand + lik = -sum(S.w_b(sp).*log(1-1./ef(sp)))... % liklihood + +sum(S.w_b(nosp).*f_kdt(nosp)); + + if nargout > 1 % if gradobj=on + dlik = -x(:,sp)*(S.w_b(sp).*f_kdt(sp)./(ef(sp)-1))... %gradient of lik + + x(:,nosp)*(S.w_b(nosp).*f_kdt(nosp)); + end + end %function f_bko + + function [lik dlik]= f_bk(RateParams) % get lik and grad + + xk = RateParams'*V.x; % filtered stimulus + hs = zeroy; % incorporate spike history terms + for l=1:V.Nspikehist, hs = hs+E.omega*S.h(:,:,l); end + s = repmat(xk,V.Nparticles,1) + hs; + + f_kdt = exp(s)*V.dt; % shorthand + ef = exp(f_kdt); % shorthand + lik = -sum(S.w_b(sp).*log(1-1./ef(sp)))... % liklihood + +sum(S.w_b(nosp).*f_kdt(nosp)); + + if nargout > 1 % if gradobj=on + dlik = -x(:,sp)*(S.w_b(sp).*f_kdt(sp)./( ef(sp)-1))... % gradient of lik + + x(:,nosp)*(S.w_b(nosp).*f_kdt(nosp)); + end + end %function f_bko + + %% MLE for calcium parameters + if V.est_c == true && ~V.CaBaselineDrift + + fprintf('estimating calcium parammeters\n') + if V.est_t == 0 + [ve_x fval] = quadprog(M.Q(2:3,2:3), M.L(2:3),[],[],[],[],[0 0],[inf inf],[E.A E.C_0/E.tau_c]+eps,optionsQP); + Enew.tau_c = E.tau_c; + Enew.A = ve_x(1); + Enew.C_0 = ve_x(2)/E.tau_c; + else + [ve_x fval] = quadprog(M.Q, M.L,[],[],[],[],[0 0 0],[inf inf inf],[1/E.tau_c E.A E.C_0/E.tau_c]+eps,optionsQP); + Enew.tau_c = 1/ve_x(1); + Enew.A = ve_x(2); + Enew.C_0 = ve_x(3)/ve_x(1); + end + fval = M.K/2 + fval; % variance + Enew.sigma_c= sqrt(fval/(M.J*V.dt)); % factor in dt + Enew.lik_c = - fval/(Enew.sigma_c*sqrt(V.dt)) - M.J*log(Enew.sigma_c); + lik = [lik Enew.lik_c]; + + Enew.a = V.dt/E.tau_c; % for brevity + Enew.sig2_c = E.sigma_c^2*V.dt; % for brevity + elseif V.CaBaselineDrift + warning('calcium parameter estimation is not yet implemented for baseline drift models, and will be skipped'); + end + + % % %% MLE for spike history parameters + % % for m=1:V.Nspikehist + % % Enew.sigma_h(m)= sum(M.v{m})/V.T; + % % end + + %% MLE for observation parameters + + if V.est_F == true + keyboard + fprintf('estimating observation parammeters\n') + ab_0 = [E.alpha E.beta]; + [Enew.lik_o ab] = f1_ab(ab_0); + + Enew.alpha = ab(1); + Enew.beta = ab(2); + Enew.zeta=E.zeta*ab(3); + Enew.gamma = E.gamma*ab(3); + lik = [lik Enew.lik_o]; + + % Sbar = Hill_v1(E,sum(S.w_b.*S.C)); + % siginv = 1./sqrt([E.gamma*Sbar+E.zeta]); + % + % ab_1 = [Sbar.*siginv; siginv]'\(F.*siginv)'; + % resid = F - ab_1(1)*Sbar - ab_1(2); + % + % X=[Sbar; 1+0*Sbar]'; + % H=X'*X; + % f=-resid*X; + % zg_1 = quadprog(H,f,-X,zeros(1,V.T)); + % zg_1 = [Sbar; 1+0*resid]'\resid'; + % + % Enew.alpha = ab_1(1); + % Enew.beta = ab_1(2); + % Enew.zeta = zg_1(1); %if Enew.zeta<0, Enew.zeta=0; end + % Enew.gamma = zg_1(2); + + end + + function [lik x] = f1_ab(ab_o) + %find MLE for {alpha, beta and gamma/zeta} + %THIS EXPLICITLY ASSUMES WEIGHTS w_b ARE SUM=1 NORMALIZED (!) + pfS=Hill_v1(E,S.C); + pfV=E.gamma*pfS+E.zeta; + % minimize quadratic form of E[(F - ab(1)*pfS - ab(2))^2/pfV] + % taken as weighted average over all particles (Fmean) + f1_abn=sum(sum(S.w_b,2)); % normalization + + f1_abH(1,1) = sum(sum(pfS.^2./pfV.*S.w_b,2));% QF + f1_abH(2,2) = sum(sum(S.w_b./pfV,2)); + f1_abH(1,2) = sum(sum(pfS./pfV.*S.w_b,2)); + f1_abH(2,1) = f1_abH(1,2); + + f1_abf=[0;0]; f1_abc=0; % LF and offset + for i=1:size(pfS,1) % over particles + f1_abf(1) = f1_abf(1) - sum(F(:)'.*pfS(i,:)./pfV(i,:).*S.w_b(i,:)); + f1_abf(2) = f1_abf(2) - sum(F(:)'./pfV(i,:).*S.w_b(i,:)); + + f1_abc = f1_abc + sum(F(:)'.^2./pfV(i,:).*S.w_b(i,:)); + end + + % solve QP given ab(1)>0, no bound on ab(2) + [x lik] = quadprog(f1_abH,f1_abf,[],[],[],[],[0 -inf],[inf inf],[],optionsQP); + + lik=(lik+f1_abc/2); % variance + x(3)=lik/f1_abn; % estimate gamma_new/gamma + + lik=-lik-sum(sum(log(x(3)*pfV).*S.w_b,2))/2; + end %function f_ab + + Enew.lik=[E.lik sum(lik)]; + + end + +end \ No newline at end of file diff --git a/z1.m b/z1.m new file mode 100755 index 0000000..17a7bd5 --- /dev/null +++ b/z1.m @@ -0,0 +1,11 @@ +function x = z1(y) +% linear normalize between 0 and 1 +x = (y-min(y(:)))/(max(y(:))-min(y(:)))+eps; + +% for multidimensional stuff, this normalizes each column to between 0 and +% 1 independent of other columns + +% T=length(y); +% y=y'; +% miny=min(y); +% x = (y-repmat(miny,T,1))./repmat(max(y)-min(y),T,1); \ No newline at end of file