function [P3, S_bar, V, RO, Tr, Z, sigma_sq, phi, Q, mu0, sigma0] = em_sfm(P, MD, K, use_lds, tol, max_em_iter) % Non-Rigid Structure From Motion with Gaussian/LDS Deformation Model % Copyright (c) by Lorenzo Torresani, Stanford University % % Based on the following paper: % % Lorenzo Torresani, Aaron Hertzmann and Christoph Bregler, % Learning Non-Rigid 3D Shape from 2D Motion, NIPS 16, 2003 % http://cs.stanford.edu/~ltorresa/projects/learning-nr-shape/ % % Please refer to this publication if you use this program for % research or for technical applications. % % % INPUT: % % P - (2*T) x J tracking matrix: P([t t+T],:) contains the 2D projections of the J points at time t % MD - T x J missing data binary matrix: MD(t, j)=1 if no valid data is available for point j at time t, 0 otherwise % K - number of deformation basis % use_lds - set to 1 to model deformations using a linear dynamical system; set to 0 otherwise % tol - termination tolerance (proportional change in likelihood) % max_em_iter - maximum number of EM iterations % % % OUTPUT: % % P3 - (3*T) x J 3D-motion matrix: ( P3([t t+T t+2*T],:) contains the 3D coordinates of the J points at time t ) % S_bar - shape average: 3 x J matrix % V - deformation shapes: (3*K) x J matrix ( V((n-1)*3+[1:3],:) contains the n-th deformation basis ) % RO - rotation: cell array ( RO{t} gives the rotation matrix at time t ) % Tr - translation: T x 2 matrix % Z - deformation weights: T x K matrix % sigma_sq - variance of the noise in feature position % phi - LDS transition matrix % Q - LDS state noise matrix % mu0 - initial state mean % sigma0 - initial state variance if mod(size(P,1), 1) ~= 0, fprintf('Error: size(P) must be (2*T)xJ\n'); return; end if (size(P,1)/2 ~= size(MD,1)) | (size(P,2) ~= size(MD,2)) fprintf('Error: Size incompatibility between P and MD\n'); return; end if mod(K, 1) ~= 0, fprintf('Error: K must be an integer value\n'); return; end [T, J] = size(MD); r = 3*(K + 1); % motion rank P_hat = P; % if any of the points are missing, P_hat will be updated during the M-step % uses rank 3 factorization to get a first initialization for rotation and S_bar [R_init, Trvect, S_bar] = rigidfac(P_hat, MD); Tr(:,1) = Trvect(1:T); Tr(:,2) = Trvect(T+1:2*T); R = zeros(2*T, 3); % enforces rotation constraints for t = 1:T, Ru = R_init(t,:); Rv = R_init(T+t,:); Rz = cross(Ru,Rv); if det([Ru;Rv;Rz])<0, Rz = -Rz; end; RO_approx = apprRot([Ru;Rv;Rz]); RO{t} = RO_approx; R(t,:) = RO_approx(1,:); R(t+T,:) = RO_approx(2,:); end; % given the initial estimates of rotation, translation and shape average, it initializes % deformation shapes and weights through LSQ minimization of the reprojection error [V, Z] = init_SB(P_hat, Tr, R, S_bar, K); % initializes sigma_sq E_zz_init = cov(Z); E_zz_init = repmat(E_zz_init, T, 1); sigma_sq = mstep_update_noisevar(P_hat, S_bar, V, Z', E_zz_init, RO, Tr); if use_lds, [phi, mu0, sigma0, Q] = init_lds(P_hat, S_bar, V, R, Tr, sigma_sq); else phi = []; mu0 = []; sigma0 = []; Q = []; end loglik = 0; annealing_const = 60; max_anneal_iter = round(max_em_iter/2); for em_iter=1:max_em_iter, if use_lds, [E_z, E_zz, y, M, xt_n, Pt_n, Ptt1_n, xt_t1, Pt_t1] = estep_lds_compute_Z_distr(P_hat, S_bar, V, R, Tr, phi, mu0, sigma0, Q, sigma_sq); [phi, Q, sigma_sq, mu0, sigma0] = mstep_lds_update(y, M, xt_n, Pt_n, Ptt1_n); else % computes the hidden variables distributions [E_z, E_zz] = estep_compute_Z_distr(P_hat, S_bar, V, R, Tr, sigma_sq); % (Eq 17-18) end Z = E_z'; % updates shape basis [S_bar, V] = mstep_update_shapebasis(P_hat, E_z, E_zz, R, Tr, S_bar, V); % (Eq 21) % fills in missing points if sum(MD(:))>0, P_hat = mstep_update_missingdata(P_hat, MD, S_bar, V, E_z, RO, Tr); % (Eq 25) end % updates rotation [RO, R] = mstep_update_rotation(P_hat, S_bar, V, E_z, E_zz, RO, Tr); % (Eq 24) % updates translation Tr = mstep_update_transl(P_hat, S_bar, V, E_z, RO); % (Eq 23) if ~use_lds, % updates noise variance sigma_sq = mstep_update_noisevar(P_hat, S_bar, V, E_z, E_zz, RO, Tr); % (Eq 22) if em_iter < max_anneal_iter, sigma_sq = sigma_sq * (1 + annealing_const*(1 - em_iter/max_anneal_iter)); end oldloglik = loglik; % computes log likelihood loglik = compute_log_lik(P_hat, S_bar, V, E_z, E_zz, RO, Tr, sigma_sq); fprintf('LogLik(%d): %f\n', em_iter, loglik); if (em_iter <= 2), loglikbase = loglik; elseif (loglik < oldloglik) fprintf('Violation'); % keyboard; elseif 0 & ((loglik-loglikbase)<(1 + tol)*(oldloglik-loglikbase)), fprintf('\n'); break; end else fprintf('Iteration %d/%d\n', em_iter, max_em_iter); end end P3 = zeros(3*T, J); for t = 1:T, z_t = Z(t,:); Rf = [R(t,:); R(t+T,:)]; S = S_bar; for kk = 1:K, S = S+z_t(kk)*V((kk-1)*3+[1:3],:); end; S = RO{t}*S; P3([t t+T t+2*T], :) = S + [Tr(t, [1 2]) -mean(S(3,:))]'*ones(1,J); end