DISFA and SEMAINE with new interface.

This commit is contained in:
Tadas Baltrusaitis 2017-12-08 17:22:30 +00:00
parent 8e12af6c44
commit 9fec91e997
5 changed files with 40 additions and 74 deletions

View file

@ -56,7 +56,7 @@ namespace Utilities
public: public:
// The constructor for the recorder, need to specify if we are recording a sequence or not // The constructor for the recorder, need to specify if we are recording a sequence or not, in_filename should be just the name and not contain extensions
RecorderOpenFace(const std::string in_filename, RecorderOpenFaceParameters parameters, std::vector<std::string>& arguments); RecorderOpenFace(const std::string in_filename, RecorderOpenFaceParameters parameters, std::vector<std::string>& arguments);
~RecorderOpenFace(); ~RecorderOpenFace();

View file

@ -76,7 +76,7 @@ RecorderOpenFace::RecorderOpenFace(const std::string in_filename, RecorderOpenFa
{ {
// From the filename, strip out the name without directory and extension // From the filename, strip out the name without directory and extension
filename = path(string(in_filename)).replace_extension("").filename().string(); filename = in_filename;
// Consuming the input arguments // Consuming the input arguments
bool* valid = new bool[arguments.size()]; bool* valid = new bool[arguments.size()];
@ -140,7 +140,7 @@ RecorderOpenFace::RecorderOpenFace(const std::string in_filename, RecorderOpenFa
metadata_file << "Input:" << in_filename << endl; metadata_file << "Input:" << in_filename << endl;
// Create the required individual recorders, CSV, HOG, aligned, video // Create the required individual recorders, CSV, HOG, aligned, video
csv_filename = (path(record_root) / path(filename).replace_extension(".csv")).string(); csv_filename = (path(record_root) / path(filename).concat(".csv")).string();
metadata_file << "Output csv:" << csv_filename << endl; metadata_file << "Output csv:" << csv_filename << endl;
// Consruct HOG recorder here // Consruct HOG recorder here

View file

@ -1,4 +1,4 @@
function [ labels, valid_ids, vid_ids ] = extract_SEMAINE_labels( SEMAINE_dir, recs, aus ) function [ labels, valid_ids, vid_ids, vid_names ] = extract_SEMAINE_labels( SEMAINE_dir, recs, aus )
%EXTRACT_SEMAINE_LABELS Summary of this function goes here %EXTRACT_SEMAINE_LABELS Summary of this function goes here
% Detailed explanation goes here % Detailed explanation goes here
@ -16,6 +16,7 @@ function [ labels, valid_ids, vid_ids ] = extract_SEMAINE_labels( SEMAINE_dir,
labels = cell(numel(recs), 1); labels = cell(numel(recs), 1);
valid_ids = cell(numel(recs), 1); valid_ids = cell(numel(recs), 1);
vid_names = cell(numel(recs), 1);
vid_ids = zeros(numel(recs), 2); vid_ids = zeros(numel(recs), 2);
for i=1:numel(recs) for i=1:numel(recs)
@ -24,6 +25,9 @@ function [ labels, valid_ids, vid_ids ] = extract_SEMAINE_labels( SEMAINE_dir,
vid_ids(i,:) = dlmread([SEMAINE_dir, '/', recs{i}, '.txt'], ' '); vid_ids(i,:) = dlmread([SEMAINE_dir, '/', recs{i}, '.txt'], ' ');
vid_names_c = dir([SEMAINE_dir, '/', recs{i}, '/*.avi']);
[~, vid_names{i},~] = fileparts(vid_names_c.name);
xml_file = [SEMAINE_dir, recs{i}, '\' file.name]; xml_file = [SEMAINE_dir, recs{i}, '\' file.name];
[root_xml, name_xml, ~] = fileparts(xml_file); [root_xml, name_xml, ~] = fileparts(xml_file);

View file

@ -18,30 +18,23 @@ end
videos = dir([DISFA_dir, '*.avi']); videos = dir([DISFA_dir, '*.avi']);
output = 'out_DISFA/'; output = 'out_DISFA/';
if(~exist(output, 'file'))
mkdir(output);
end
%% %%
% Do it in parrallel for speed (replace the parfor with for if no parallel % Do it in parrallel for speed (replace the parfor with for if no parallel
% toolbox is available) % toolbox is available)
parfor v = 1:numel(videos) % parfor v = 1:numel(videos)
%
vid_file = [DISFA_dir, videos(v).name]; % vid_file = [DISFA_dir, videos(v).name];
%
[~, name, ~] = fileparts(vid_file); % command = sprintf('%s -f "%s" -out_dir "%s" -aus ', executable, vid_file, output);
%
% where to output tracking results % if(isunix)
output_file = [output name '_au.txt']; % unix(command, '-echo');
command = [executable ' -f "' vid_file '" -of "' output_file '" -q -no2Dfp -no3Dfp -noMparams -noPose -noGaze']; % else
% dos(command);
if(isunix) % end
unix(command, '-echo'); %
else % end
dos(command);
end
end
%% Now evaluate the predictions %% Now evaluate the predictions
@ -70,30 +63,32 @@ for i=1:numel(label_folders)
label_ids = cat(1, label_ids, repmat(user_id, size(labels,1),1)); label_ids = cat(1, label_ids, repmat(user_id, size(labels,1),1));
end end
preds_files = dir([prediction_dir, '*SN*.txt']); preds_files = dir([prediction_dir, '*SN*.csv']);
tab = readtable([prediction_dir, preds_files(1).name]); tab = readtable([prediction_dir, preds_files(1).name]);
column_names = tab.Properties.VariableNames; column_names = tab.Properties.VariableNames;
aus_pred_int = []; aus_pred_int = [];
for c=3:numel(column_names) au_inds_in_file = [];
for c=1:numel(column_names)
if(strfind(column_names{c}, '_r') > 0) if(strfind(column_names{c}, '_r') > 0)
aus_pred_int = cat(1, aus_pred_int, int32(str2num(column_names{c}(3:end-2)))); aus_pred_int = cat(1, aus_pred_int, int32(str2num(column_names{c}(3:end-2))));
au_inds_in_file = cat(1, au_inds_in_file, c);
end end
end end
inds_au = zeros(numel(AUs_disfa),1); inds_au = zeros(numel(AUs_disfa),1);
for ind=1:numel(AUs_disfa) for ind=1:numel(AUs_disfa)
inds_au(ind) = find(aus_pred_int==AUs_disfa(ind)); inds_au(ind) = au_inds_in_file(aus_pred_int==AUs_disfa(ind));
end end
preds_all = zeros(size(labels_all,1), numel(AUs_disfa)); preds_all = zeros(size(labels_all,1), numel(AUs_disfa));
for i=1:numel(preds_files) for i=1:numel(preds_files)
preds = dlmread([prediction_dir, preds_files(i).name], ',', 1, 0); preds = dlmread([prediction_dir, preds_files(i).name], ',', 1, 0);
preds = preds(:,5:5+numel(aus_pred_int)-1); %preds = preds(:,5:5+numel(aus_pred_int)-1);
user_id = str2num(preds_files(i).name(end - 14:end-12)); user_id = str2num(preds_files(i).name(end - 11:end-9));
rel_ids = label_ids == user_id; rel_ids = label_ids == user_id;
preds_all(rel_ids,:) = preds(:,inds_au); preds_all(rel_ids,:) = preds(:,inds_au);
end end

View file

@ -5,10 +5,6 @@ find_SEMAINE;
out_loc = './out_SEMAINE/'; out_loc = './out_SEMAINE/';
if(~exist(out_loc, 'dir'))
mkdir(out_loc);
end
if(isunix) if(isunix)
executable = '"../../build/bin/FeatureExtraction"'; executable = '"../../build/bin/FeatureExtraction"';
else else
@ -24,14 +20,9 @@ parfor f1=1:numel(devel_recs)
f1_dir = devel_recs{f1}; f1_dir = devel_recs{f1};
command = [executable, ' -fx 800 -fy 800 -q -no2Dfp -no3Dfp -noMparams -noPose -noGaze '];
curr_vid = [SEMAINE_dir, f1_dir, '/', vid_file.name]; curr_vid = [SEMAINE_dir, f1_dir, '/', vid_file.name];
name = f1_dir; command = sprintf('%s -aus -f "%s" -out_dir "%s" ', executable, curr_vid, out_loc);
output_aus = [out_loc name '.au.txt'];
command = cat(2, command, [' -f "' curr_vid '" -of "' output_aus, '"']);
if(isunix) if(isunix)
unix(command, '-echo'); unix(command, '-echo');
@ -43,26 +34,19 @@ parfor f1=1:numel(devel_recs)
end end
%% Actual model evaluation %% Actual model evaluation
[ labels, valid_ids, vid_ids ] = extract_SEMAINE_labels(SEMAINE_dir, devel_recs, aus_SEMAINE); [ labels, valid_ids, vid_ids, vid_names ] = extract_SEMAINE_labels(SEMAINE_dir, devel_recs, aus_SEMAINE);
labels_gt = cat(1, labels{:}); labels_gt = cat(1, labels{:});
%% Identifying which column IDs correspond to which AU %% Identifying which column IDs correspond to which AU
tab = readtable([out_loc, devel_recs{1}, '.au.txt']); tab = readtable([out_loc, vid_names{1}, '.csv']);
column_names = tab.Properties.VariableNames; column_names = tab.Properties.VariableNames;
% As there are both classes and intensities list and evaluate both of them % As there are both classes and intensities list and evaluate both of them
aus_pred_int = [];
aus_pred_class = []; aus_pred_class = [];
inds_int_in_file = [];
inds_class_in_file = []; inds_class_in_file = [];
for c=1:numel(column_names) for c=1:numel(column_names)
if(strfind(column_names{c}, '_r') > 0)
aus_pred_int = cat(1, aus_pred_int, int32(str2num(column_names{c}(3:end-2))));
inds_int_in_file = cat(1, inds_int_in_file, c);
end
if(strfind(column_names{c}, '_c') > 0) if(strfind(column_names{c}, '_c') > 0)
aus_pred_class = cat(1, aus_pred_class, int32(str2num(column_names{c}(3:end-2)))); aus_pred_class = cat(1, aus_pred_class, int32(str2num(column_names{c}(3:end-2))));
inds_class_in_file = cat(1, inds_class_in_file, c); inds_class_in_file = cat(1, inds_class_in_file, c);
@ -70,37 +54,20 @@ for c=1:numel(column_names)
end end
%% %%
inds_au_int = zeros(size(aus_SEMAINE));
inds_au_class = zeros(size(aus_SEMAINE)); inds_au_class = zeros(size(aus_SEMAINE));
for ind=1:numel(aus_SEMAINE)
if(~isempty(find(aus_pred_int==aus_SEMAINE(ind), 1)))
inds_au_int(ind) = find(aus_pred_int==aus_SEMAINE(ind));
end
end
for ind=1:numel(aus_SEMAINE) for ind=1:numel(aus_SEMAINE)
if(~isempty(find(aus_pred_class==aus_SEMAINE(ind), 1))) if(~isempty(find(aus_pred_class==aus_SEMAINE(ind), 1)))
inds_au_class(ind) = find(aus_pred_class==aus_SEMAINE(ind)); inds_au_class(ind) = inds_class_in_file(aus_pred_class==aus_SEMAINE(ind));
end end
end end
preds_all_class = []; preds_all = [];
preds_all_int = []; for i=1:numel(vid_names)
for i=1:numel(devel_recs)
fname = [out_loc, devel_recs{i}, '.au.txt']; fname = [out_loc, vid_names{i}, '.csv'];
preds = dlmread(fname, ',', 1, 0); preds = dlmread(fname, ',', 1, 0);
preds_all = cat(1, preds_all, preds(vid_ids(i,1):vid_ids(i,2) - 1, :));
% Read all of the intensity AUs
preds_int = preds(vid_ids(i,1):vid_ids(i,2) - 1, inds_int_in_file);
% Read all of the classification AUs
preds_class = preds(vid_ids(i,1):vid_ids(i,2) - 1, inds_class_in_file);
preds_all_class = cat(1, preds_all_class, preds_class);
preds_all_int = cat(1, preds_all_int, preds_int);
end end
%% %%
@ -109,10 +76,10 @@ f1s = zeros(1, numel(aus_SEMAINE));
for au = 1:numel(aus_SEMAINE) for au = 1:numel(aus_SEMAINE)
if(inds_au_class(au) ~= 0) if(inds_au_class(au) ~= 0)
tp = sum(labels_gt(:,au) == 1 & preds_all_class(:, inds_au_class(au)) == 1); tp = sum(labels_gt(:,au) == 1 & preds_all(:, inds_au_class(au)) == 1);
fp = sum(labels_gt(:,au) == 0 & preds_all_class(:, inds_au_class(au)) == 1); fp = sum(labels_gt(:,au) == 0 & preds_all(:, inds_au_class(au)) == 1);
fn = sum(labels_gt(:,au) == 1 & preds_all_class(:, inds_au_class(au)) == 0); fn = sum(labels_gt(:,au) == 1 & preds_all(:, inds_au_class(au)) == 0);
tn = sum(labels_gt(:,au) == 0 & preds_all_class(:, inds_au_class(au)) == 0); tn = sum(labels_gt(:,au) == 0 & preds_all(:, inds_au_class(au)) == 0);
precision = tp./(tp+fp); precision = tp./(tp+fp);
recall = tp./(tp+fn); recall = tp./(tp+fn);