% sustaining_gazes/matlab_version/pdm_generation/nrsfm-em/kalmansmooth.m
%
% 70 lines
% 2.5 KiB
% Matlab

function [xt_n, Pt_n, Ptt1_n, xt_t1, Pt_t1] = kalmansmooth(y, M, maxIter, phi, mu0, sigma0, Q, sigma_sq, p, q, n)
% KALMANSMOOTH  Kalman filter followed by a fixed-interval (RTS) smoother and
% the lag-one covariance smoother, for a linear Gaussian state-space model
% with time-varying observation matrices.
%
% Model implemented below:
%   x_t = phi * x_{t-1} + w_t,     Cov(w_t) = Q
%   y_t = M{t+1} * x_t + v_t,      Cov(v_t) = sigma_sq * eye(q)
%   initial filtered state: mean mu0, covariance sigma0 (time t = 0)
%
% Adapted from code written by Hrishi Deshpande
% Based on eqn (A3)-(A12) of Appendix A in Shumway, R.H. & Stoffer, D.S. (1982),
% "An approach to time series smoothing and forecasting using the EM algorithm",
% Journal of Time Series Analysis, 3, 253-264
%
% Inputs:
%   y        - q x (n+1) observations; column t+1 holds y_t for t = 1..n.
%              Column 1 (time 0) is never read.
%   M        - cell array of q x p observation matrices; M{t+1} is used at
%              time t = 1..n. M{1} is never read.
%   maxIter  - unused; kept only for interface compatibility.
%   phi      - p x p state transition matrix.
%   mu0      - p x 1 initial state mean.
%   sigma0   - p x p initial state covariance.
%   Q        - p x p process-noise covariance.
%   sigma_sq - scalar observation-noise variance (isotropic noise).
%   p, q     - state and observation dimensions.
%   n        - number of time steps (n >= 0).
%
% Outputs (column / cell t+1 corresponds to time t = 0..n):
%   xt_n   - p x (n+1) smoothed state means.
%   Pt_n   - 1 x (n+1) cell of smoothed state covariances.
%   Ptt1_n - 1 x (n+1) cell of smoothed lag-one covariances P_{t,t-1}^n,
%            filled for t = 1..n (cell 1 stays empty).
%   xt_t1  - p x (n+1) one-step predicted means (column 1 stays zero).
%   Pt_t1  - 1 x (n+1) cell of one-step predicted covariances
%            (cell 1 stays empty).

% Preallocate everything; index k always refers to time t = k-1.
xt_t1  = zeros(p, n+1);
Pt_t1  = cell(1, n+1);
K      = cell(1, n+1);
xt_t   = zeros(p, n+1);
Pt_t   = cell(1, n+1);
Jt     = cell(1, n+1);
xt_n   = zeros(p, n+1);
Pt_n   = cell(1, n+1);
Ptt1_n = cell(1, n+1);

% Time 0: the filtered estimate is the prior.
xt_t(:,1) = mu0;
Pt_t{1}   = sigma0;

% With no observations the smoothed estimate is just the prior. Without this
% guard the lag-one step (A12) would index time -1 and error out.
if n < 1
    xt_n(:,1) = mu0;
    Pt_n{1}   = sigma0;
    return;
end

%- - - - - - - - - - - - - - - -
% Forward (filtering) steps. Eqns (A3)-(A7).
R = sigma_sq*eye(q);   % observation-noise covariance, loop invariant
for t = 1:n
    tIdx = t+1;
    Mt = M{tIdx};
    xt_t1(:,tIdx) = phi*xt_t(:,tIdx-1);                                    % (A3)
    Pt_t1{tIdx}   = phi*Pt_t{tIdx-1}*phi' + Q;                             % (A4)
    % Kalman gain P M' (M P M' + R)^-1, solved with mrdivide rather than an
    % explicit inv() for accuracy.
    PMt = Pt_t1{tIdx}*Mt';
    K{tIdx} = PMt / (Mt*PMt + R);                                          % (A5)
    xt_t(:,tIdx) = xt_t1(:,tIdx) + K{tIdx}*(y(:,tIdx) - Mt*xt_t1(:,tIdx)); % (A6)
    Pt_t{tIdx}   = Pt_t1{tIdx} - K{tIdx}*Mt*Pt_t1{tIdx};                   % (A7)
end

%- - - - - - - - - - - - - - - -
% Backward (smoothing) steps. Eqns (A8)-(A10).
xt_n(:,n+1) = xt_t(:,n+1);   % (A9) initial condition at t = n
Pt_n{n+1}   = Pt_t{n+1};     % (A10) initial condition at t = n
for t = n:-1:1
    tIdx = t+1;
    % Smoother gain J_{t-1} = P_{t-1}^{t-1} phi' (P_t^{t-1})^-1, via mrdivide.
    Jt{tIdx-1} = (Pt_t{tIdx-1}*phi') / Pt_t1{tIdx};                                  % (A8)
    xt_n(:,tIdx-1) = xt_t(:,tIdx-1) + Jt{tIdx-1}*(xt_n(:,tIdx) - phi*xt_t(:,tIdx-1)); % (A9)
    Pt_n{tIdx-1}   = Pt_t{tIdx-1} + Jt{tIdx-1}*(Pt_n{tIdx} - Pt_t1{tIdx})*Jt{tIdx-1}'; % (A10)
end

%- - - - - - - - - - - - - - - -
% Lag-one covariance smoother. Eqns (A11)-(A12).
Ptt1_n{n+1} = (eye(p) - K{n+1}*M{n+1}) * phi * Pt_t{n};   % (A12) at t = n
for t = n:-1:2
    tIdx = t+1;
    Ptt1_n{tIdx-1} = Pt_t{tIdx-1}*Jt{tIdx-2}' + ...
        Jt{tIdx-1}*(Ptt1_n{tIdx} - phi*Pt_t{tIdx-1})*Jt{tIdx-2}';   % (A11)
end