5959176921
- face validator is a retrained CNN now - starting retiring CLM-Z from OpenFace
632 lines
20 KiB
Matlab
632 lines
20 KiB
Matlab
function [net, stats] = cnn_train_reg(net, imdb, getBatch, varargin)
|
|
%cnn_train_reg An example implementation of SGD for training CNNs
|
|
% CNN_TRAIN() is an example learner implementing stochastic
|
|
% gradient descent with momentum to train a CNN. It can be used
|
|
% with different datasets and tasks by providing a suitable
|
|
% getBatch function.
|
|
%
|
|
% The function automatically restarts after each training epoch by
|
|
% checkpointing.
|
|
%
|
|
% The function supports training on CPU or on one or more GPUs
|
|
% (specify the list of GPU IDs in the `gpus` option).
|
|
|
|
% Copyright (C) 2014-16 Andrea Vedaldi.
|
|
% All rights reserved.
|
|
%
|
|
% This file is part of the VLFeat library and is made available under
|
|
% the terms of the BSD license (see the COPYING file).
|
|
|
|
% This is a modified version for regression using the CNN_TRAIN from
|
|
% MatConvNet
|
|
|
|
addpath(fullfile(vl_rootnn, 'examples'));
|
|
|
|
opts.expDir = fullfile('data','exp') ;
|
|
opts.continue = true ;
|
|
opts.batchSize = 256 ;
|
|
opts.numSubBatches = 1 ;
|
|
opts.train = [] ;
|
|
opts.val = [] ;
|
|
opts.gpus = [] ;
|
|
opts.epochSize = inf;
|
|
opts.prefetch = false ;
|
|
opts.numEpochs = 300 ;
|
|
opts.learningRate = 0.001 ;
|
|
opts.weightDecay = 0.0005 ;
|
|
|
|
opts.solver = [] ; % Empty array means use the default SGD solver
|
|
[opts, varargin] = vl_argparse(opts, varargin) ;
|
|
if ~isempty(opts.solver)
|
|
assert(isa(opts.solver, 'function_handle') && nargout(opts.solver) == 2,...
|
|
'Invalid solver; expected a function handle with two outputs.') ;
|
|
% Call without input arguments, to get default options
|
|
opts.solverOpts = opts.solver() ;
|
|
end
|
|
|
|
opts.momentum = 0.9 ;
|
|
opts.saveSolverState = true ;
|
|
opts.nesterovUpdate = false ;
|
|
opts.randomSeed = 0 ;
|
|
opts.memoryMapFile = fullfile(tempdir, 'matconvnet.bin') ;
|
|
opts.profile = false ;
|
|
opts.parameterServer.method = 'mmap' ;
|
|
opts.parameterServer.prefix = 'mcn' ;
|
|
|
|
opts.conserveMemory = true ;
|
|
opts.backPropDepth = +inf ;
|
|
opts.sync = false ;
|
|
opts.cudnn = true ;
|
|
opts.errorFunction = 'regression' ;
|
|
opts.errorLabels = {} ;
|
|
opts.plotDiagnostics = false ;
|
|
opts.plotStatistics = true;
|
|
opts.postEpochFn = [] ; % postEpochFn(net,params,state) called after each epoch; can return a new learning rate, 0 to stop, [] for no change
|
|
opts = vl_argparse(opts, varargin) ;
|
|
|
|
if ~exist(opts.expDir, 'dir'), mkdir(opts.expDir) ; end
|
|
if isempty(opts.train), opts.train = find(imdb.images.set==1) ; end
|
|
if isempty(opts.val), opts.val = find(imdb.images.set==2) ; end
|
|
if isscalar(opts.train) && isnumeric(opts.train) && isnan(opts.train)
|
|
opts.train = [] ;
|
|
end
|
|
if isscalar(opts.val) && isnumeric(opts.val) && isnan(opts.val)
|
|
opts.val = [] ;
|
|
end
|
|
|
|
hasError = true ;
|
|
opts.errorLabels = {'correlation', 'rmse'};
|
|
|
|
% -------------------------------------------------------------------------
|
|
% Initialization
|
|
% -------------------------------------------------------------------------
|
|
|
|
net = vl_simplenn_tidy(net); % fill in some eventually missing values
|
|
net.layers{end-1}.precious = 1; % do not remove predictions, used for error
|
|
vl_simplenn_display(net, 'batchSize', opts.batchSize) ;
|
|
|
|
evaluateMode = isempty(opts.train) ;
|
|
if ~evaluateMode
|
|
for i=1:numel(net.layers)
|
|
J = numel(net.layers{i}.weights) ;
|
|
if ~isfield(net.layers{i}, 'learningRate')
|
|
net.layers{i}.learningRate = ones(1, J) ;
|
|
end
|
|
if ~isfield(net.layers{i}, 'weightDecay')
|
|
net.layers{i}.weightDecay = ones(1, J) ;
|
|
end
|
|
end
|
|
end
|
|
|
|
state.getBatch = getBatch ;
|
|
stats = [] ;
|
|
|
|
% -------------------------------------------------------------------------
|
|
% Train and validate
|
|
% -------------------------------------------------------------------------
|
|
|
|
modelPath = @(ep) fullfile(opts.expDir, sprintf('net-epoch-%d.mat', ep));
|
|
modelFigPath = fullfile(opts.expDir, 'net-train.pdf') ;
|
|
|
|
start = opts.continue * findLastCheckpoint(opts.expDir) ;
|
|
if start >= 1
|
|
fprintf('%s: resuming by loading epoch %d\n', mfilename, start) ;
|
|
[net, state, stats] = loadState(modelPath(start)) ;
|
|
else
|
|
state = [] ;
|
|
end
|
|
|
|
for epoch=start+1:opts.numEpochs
|
|
|
|
% Set the random seed based on the epoch and opts.randomSeed.
|
|
% This is important for reproducibility, including when training
|
|
% is restarted from a checkpoint.
|
|
|
|
rng(epoch + opts.randomSeed) ;
|
|
prepareGPUs(opts, epoch == start+1) ;
|
|
|
|
% Train for one epoch.
|
|
params = opts ;
|
|
params.epoch = epoch ;
|
|
params.learningRate = opts.learningRate(min(epoch, numel(opts.learningRate))) ;
|
|
params.train = opts.train(randperm(numel(opts.train))) ; % shuffle
|
|
params.train = params.train(1:min(opts.epochSize, numel(opts.train)));
|
|
params.imdb = imdb ;
|
|
params.getBatch = getBatch ;
|
|
|
|
if numel(params.gpus) <= 1
|
|
[net, state] = processEpoch(net, state, params, 'train') ;
|
|
[net, state] = processEpoch(net, state, params, 'val') ;
|
|
if ~evaluateMode
|
|
saveState(modelPath(epoch), net, state) ;
|
|
end
|
|
lastStats = state.stats ;
|
|
else
|
|
spmd
|
|
[net, state] = processEpoch(net, state, params, 'train') ;
|
|
[net, state] = processEpoch(net, state, params, 'val') ;
|
|
if labindex == 1 && ~evaluateMode
|
|
saveState(modelPath(epoch), net, state) ;
|
|
end
|
|
lastStats = state.stats ;
|
|
end
|
|
lastStats = accumulateStats(lastStats) ;
|
|
end
|
|
|
|
stats.train(epoch) = lastStats.train ;
|
|
stats.val(epoch) = lastStats.val ;
|
|
clear lastStats ;
|
|
if ~evaluateMode
|
|
saveStats(modelPath(epoch), stats) ;
|
|
end
|
|
|
|
if params.plotStatistics
|
|
switchFigure(1) ; clf ;
|
|
plots = setdiff(...
|
|
cat(2,...
|
|
fieldnames(stats.train)', ...
|
|
fieldnames(stats.val)'), {'num', 'time'}) ;
|
|
for p = plots
|
|
p = char(p) ;
|
|
values = zeros(0, epoch) ;
|
|
leg = {} ;
|
|
for f = {'train', 'val'}
|
|
f = char(f) ;
|
|
if isfield(stats.(f), p)
|
|
tmp = [stats.(f).(p)] ;
|
|
values(end+1,:) = tmp(1,:)' ;
|
|
leg{end+1} = f ;
|
|
end
|
|
end
|
|
subplot(1,numel(plots),find(strcmp(p,plots))) ;
|
|
plot(1:epoch, values','o-') ;
|
|
xlabel('epoch') ;
|
|
title(p) ;
|
|
legend(leg{:}) ;
|
|
grid on ;
|
|
end
|
|
drawnow ;
|
|
print(1, modelFigPath, '-dpdf') ;
|
|
end
|
|
|
|
if ~isempty(opts.postEpochFn)
|
|
if nargout(opts.postEpochFn) == 0
|
|
opts.postEpochFn(net, params, state) ;
|
|
else
|
|
lr = opts.postEpochFn(net, params, state) ;
|
|
if ~isempty(lr), opts.learningRate = lr; end
|
|
if opts.learningRate == 0, break; end
|
|
end
|
|
end
|
|
end
|
|
|
|
% Return the best performing model
|
|
[~,best_epoch] = min(cat(1,stats.val.rmse));
|
|
fprintf('%s: Best model in epoch %d\n', mfilename, best_epoch) ;
|
|
[net, state, stats] = loadState(modelPath(best_epoch)) ;
|
|
|
|
% With multiple GPUs, return one copy
|
|
if isa(net, 'Composite'), net = net{1} ; end
|
|
|
|
% -------------------------------------------------------------------------
|
|
function [net, state] = processEpoch(net, state, params, mode)
|
|
% -------------------------------------------------------------------------
|
|
% Note that net is not strictly needed as an output argument as net
|
|
% is a handle class. However, this fixes some aliasing issue in the
|
|
% spmd caller.
|
|
|
|
% initialize with momentum 0
|
|
if isempty(state) || isempty(state.solverState)
|
|
for i = 1:numel(net.layers)
|
|
state.solverState{i} = cell(1, numel(net.layers{i}.weights)) ;
|
|
state.solverState{i}(:) = {0} ;
|
|
end
|
|
end
|
|
|
|
% move CNN to GPU as needed
|
|
numGpus = numel(params.gpus) ;
|
|
if numGpus >= 1
|
|
net = vl_simplenn_move(net, 'gpu') ;
|
|
for i = 1:numel(state.solverState)
|
|
for j = 1:numel(state.solverState{i})
|
|
s = state.solverState{i}{j} ;
|
|
if isnumeric(s)
|
|
state.solverState{i}{j} = gpuArray(s) ;
|
|
elseif isstruct(s)
|
|
state.solverState{i}{j} = structfun(@gpuArray, s, 'UniformOutput', false) ;
|
|
end
|
|
end
|
|
end
|
|
end
|
|
if numGpus > 1
|
|
parserv = ParameterServer(params.parameterServer) ;
|
|
vl_simplenn_start_parserv(net, parserv) ;
|
|
else
|
|
parserv = [] ;
|
|
end
|
|
|
|
% profile
|
|
if params.profile
|
|
if numGpus <= 1
|
|
profile clear ;
|
|
profile on ;
|
|
else
|
|
mpiprofile reset ;
|
|
mpiprofile on ;
|
|
end
|
|
end
|
|
|
|
subset = params.(mode) ;
|
|
num = 0 ;
|
|
stats.num = 0 ; % return something even if subset = []
|
|
stats.time = 0 ;
|
|
adjustTime = 0 ;
|
|
res = [] ;
|
|
|
|
preds_all = [];
|
|
labels_all = [];
|
|
|
|
err = 0;
|
|
|
|
start = tic ;
|
|
for t=1:params.batchSize:numel(subset)
|
|
fprintf('%s: epoch %02d: %3d/%3d:', mode, params.epoch, ...
|
|
fix((t-1)/params.batchSize)+1, ceil(numel(subset)/params.batchSize)) ;
|
|
batchSize = min(params.batchSize, numel(subset) - t + 1) ;
|
|
|
|
for s=1:params.numSubBatches
|
|
% get this image batch and prefetch the next
|
|
batchStart = t + (labindex-1) + (s-1) * numlabs ;
|
|
batchEnd = min(t+params.batchSize-1, numel(subset)) ;
|
|
batch = subset(batchStart : params.numSubBatches * numlabs : batchEnd) ;
|
|
num = num + numel(batch) ;
|
|
if numel(batch) == 0, continue ; end
|
|
|
|
[im, labels] = params.getBatch(params.imdb, batch) ;
|
|
|
|
if params.prefetch
|
|
if s == params.numSubBatches
|
|
batchStart = t + (labindex-1) + params.batchSize ;
|
|
batchEnd = min(t+2*params.batchSize-1, numel(subset)) ;
|
|
else
|
|
batchStart = batchStart + numlabs ;
|
|
end
|
|
nextBatch = subset(batchStart : params.numSubBatches * numlabs : batchEnd) ;
|
|
params.getBatch(params.imdb, nextBatch) ;
|
|
end
|
|
|
|
if numGpus >= 1
|
|
im = gpuArray(im) ;
|
|
end
|
|
|
|
if strcmp(mode, 'train')
|
|
dzdy = 1 ;
|
|
evalMode = 'normal' ;
|
|
else
|
|
dzdy = [] ;
|
|
evalMode = 'test' ;
|
|
end
|
|
net.layers{end}.class = labels ;
|
|
res = vl_simplenn(net, im, dzdy, res, ...
|
|
'accumulate', s ~= 1, ...
|
|
'mode', evalMode, ...
|
|
'conserveMemory', params.conserveMemory, ...
|
|
'backPropDepth', params.backPropDepth, ...
|
|
'sync', params.sync, ...
|
|
'cudnn', params.cudnn, ...
|
|
'parameterServer', parserv, ...
|
|
'holdOn', s < params.numSubBatches) ;
|
|
|
|
predictions = gather(res(end-1).x) ;
|
|
[~,predictions] = sort(predictions, 3, 'descend') ;
|
|
predictions = squeeze(predictions);
|
|
num_bins = size(predictions,1);
|
|
predictions = predictions(1,:);
|
|
|
|
% Convert the class labels into the continuous values
|
|
labels = unQuantizeContinuous(squeeze(labels), 0, 3, num_bins)';
|
|
|
|
predictions = unQuantizeContinuous(squeeze(predictions), 0, 3, num_bins)';
|
|
preds_all = cat(1, preds_all, predictions);
|
|
labels_all = cat(1, labels_all, labels);
|
|
|
|
err = [err(1)+sum(double(gather(res(end).x)));...
|
|
corr(labels_all, preds_all);...
|
|
sqrt(mean((labels_all-preds_all).^2))];
|
|
end
|
|
|
|
% accumulate gradient
|
|
if strcmp(mode, 'train')
|
|
if ~isempty(parserv), parserv.sync() ; end
|
|
[net, res, state] = accumulateGradients(net, res, state, params, batchSize, parserv) ;
|
|
end
|
|
|
|
% get statistics
|
|
time = toc(start) + adjustTime ;
|
|
batchTime = time - stats.time ;
|
|
|
|
stats = extractStats(net, params, [err(1)/num; err(2); err(3)]) ;
|
|
stats.num = num ;
|
|
stats.time = time ;
|
|
currentSpeed = batchSize / batchTime ;
|
|
averageSpeed = (t + batchSize - 1) / time ;
|
|
if t == 3*params.batchSize + 1
|
|
% compensate for the first three iterations, which are outliers
|
|
adjustTime = 4*batchTime - time ;
|
|
stats.time = time + adjustTime ;
|
|
end
|
|
|
|
fprintf(' %.1f (%.1f) Hz', averageSpeed, currentSpeed) ;
|
|
for f = setdiff(fieldnames(stats)', {'num', 'time'})
|
|
f = char(f) ;
|
|
fprintf(' %s: %.3f', f, stats.(f)) ;
|
|
end
|
|
fprintf('\n') ;
|
|
|
|
% collect diagnostic statistics
|
|
if strcmp(mode, 'train') && params.plotDiagnostics
|
|
switchFigure(2) ; clf ;
|
|
diagn = [res.stats] ;
|
|
diagnvar = horzcat(diagn.variation) ;
|
|
diagnpow = horzcat(diagn.power) ;
|
|
subplot(2,2,1) ; barh(diagnvar) ;
|
|
set(gca,'TickLabelInterpreter', 'none', ...
|
|
'YTick', 1:numel(diagnvar), ...
|
|
'YTickLabel',horzcat(diagn.label), ...
|
|
'YDir', 'reverse', ...
|
|
'XScale', 'log', ...
|
|
'XLim', [1e-5 1], ...
|
|
'XTick', 10.^(-5:1)) ;
|
|
grid on ; title('Variation');
|
|
subplot(2,2,2) ; barh(sqrt(diagnpow)) ;
|
|
set(gca,'TickLabelInterpreter', 'none', ...
|
|
'YTick', 1:numel(diagnpow), ...
|
|
'YTickLabel',{diagn.powerLabel}, ...
|
|
'YDir', 'reverse', ...
|
|
'XScale', 'log', ...
|
|
'XLim', [1e-5 1e5], ...
|
|
'XTick', 10.^(-5:5)) ;
|
|
grid on ; title('Power');
|
|
subplot(2,2,3); plot(squeeze(res(end-1).x)) ;
|
|
drawnow ;
|
|
end
|
|
end
|
|
|
|
% Save back to state.
|
|
state.stats.(mode) = stats ;
|
|
if params.profile
|
|
if numGpus <= 1
|
|
state.prof.(mode) = profile('info') ;
|
|
profile off ;
|
|
else
|
|
state.prof.(mode) = mpiprofile('info');
|
|
mpiprofile off ;
|
|
end
|
|
end
|
|
if ~params.saveSolverState
|
|
state.solverState = [] ;
|
|
else
|
|
for i = 1:numel(state.solverState)
|
|
for j = 1:numel(state.solverState{i})
|
|
s = state.solverState{i}{j} ;
|
|
if isnumeric(s)
|
|
state.solverState{i}{j} = gather(s) ;
|
|
elseif isstruct(s)
|
|
state.solverState{i}{j} = structfun(@gather, s, 'UniformOutput', false) ;
|
|
end
|
|
end
|
|
end
|
|
end
|
|
|
|
net = vl_simplenn_move(net, 'cpu') ;
|
|
|
|
% -------------------------------------------------------------------------
|
|
function [net, res, state] = accumulateGradients(net, res, state, params, batchSize, parserv)
|
|
% -------------------------------------------------------------------------
|
|
numGpus = numel(params.gpus) ;
|
|
otherGpus = setdiff(1:numGpus, labindex) ;
|
|
|
|
for l=numel(net.layers):-1:1
|
|
for j=numel(res(l).dzdw):-1:1
|
|
|
|
if ~isempty(parserv)
|
|
tag = sprintf('l%d_%d',l,j) ;
|
|
parDer = parserv.pull(tag) ;
|
|
else
|
|
parDer = res(l).dzdw{j} ;
|
|
end
|
|
|
|
if j == 3 && strcmp(net.layers{l}.type, 'bnorm')
|
|
% special case for learning bnorm moments
|
|
thisLR = net.layers{l}.learningRate(j) ;
|
|
net.layers{l}.weights{j} = vl_taccum(...
|
|
1 - thisLR, ...
|
|
net.layers{l}.weights{j}, ...
|
|
thisLR / batchSize, ...
|
|
parDer) ;
|
|
else
|
|
% Standard gradient training.
|
|
thisDecay = params.weightDecay * net.layers{l}.weightDecay(j) ;
|
|
thisLR = params.learningRate * net.layers{l}.learningRate(j) ;
|
|
|
|
if thisLR>0 || thisDecay>0
|
|
% Normalize gradient and incorporate weight decay.
|
|
parDer = vl_taccum(1/batchSize, parDer, ...
|
|
thisDecay, net.layers{l}.weights{j}) ;
|
|
|
|
if isempty(params.solver)
|
|
% Default solver is the optimised SGD.
|
|
% Update momentum.
|
|
state.solverState{l}{j} = vl_taccum(...
|
|
params.momentum, state.solverState{l}{j}, ...
|
|
-1, parDer) ;
|
|
|
|
% Nesterov update (aka one step ahead).
|
|
if params.nesterovUpdate
|
|
delta = params.momentum * state.solverState{l}{j} - parDer ;
|
|
else
|
|
delta = state.solverState{l}{j} ;
|
|
end
|
|
|
|
% Update parameters.
|
|
net.layers{l}.weights{j} = vl_taccum(...
|
|
1, net.layers{l}.weights{j}, ...
|
|
thisLR, delta) ;
|
|
|
|
else
|
|
% call solver function to update weights
|
|
[net.layers{l}.weights{j}, state.solverState{l}{j}] = ...
|
|
params.solver(net.layers{l}.weights{j}, state.solverState{l}{j}, ...
|
|
parDer, params.solverOpts, thisLR) ;
|
|
end
|
|
end
|
|
end
|
|
|
|
% if requested, collect some useful stats for debugging
|
|
if params.plotDiagnostics
|
|
variation = [] ;
|
|
label = '' ;
|
|
switch net.layers{l}.type
|
|
case {'conv','convt'}
|
|
if isnumeric(state.solverState{l}{j})
|
|
variation = thisLR * mean(abs(state.solverState{l}{j}(:))) ;
|
|
end
|
|
power = mean(res(l+1).x(:).^2) ;
|
|
if j == 1 % fiters
|
|
base = mean(net.layers{l}.weights{j}(:).^2) ;
|
|
label = 'filters' ;
|
|
else % biases
|
|
base = sqrt(power) ;%mean(abs(res(l+1).x(:))) ;
|
|
label = 'biases' ;
|
|
end
|
|
variation = variation / base ;
|
|
label = sprintf('%s_%s', net.layers{l}.name, label) ;
|
|
end
|
|
res(l).stats.variation(j) = variation ;
|
|
res(l).stats.power = power ;
|
|
res(l).stats.powerLabel = net.layers{l}.name ;
|
|
res(l).stats.label{j} = label ;
|
|
end
|
|
end
|
|
end
|
|
|
|
% -------------------------------------------------------------------------
|
|
function stats = accumulateStats(stats_)
|
|
% -------------------------------------------------------------------------
|
|
|
|
for s = {'train', 'val'}
|
|
s = char(s) ;
|
|
total = 0 ;
|
|
|
|
% initialize stats stucture with same fields and same order as
|
|
% stats_{1}
|
|
stats__ = stats_{1} ;
|
|
names = fieldnames(stats__.(s))' ;
|
|
values = zeros(1, numel(names)) ;
|
|
fields = cat(1, names, num2cell(values)) ;
|
|
stats.(s) = struct(fields{:}) ;
|
|
|
|
for g = 1:numel(stats_)
|
|
stats__ = stats_{g} ;
|
|
num__ = stats__.(s).num ;
|
|
total = total + num__ ;
|
|
|
|
for f = setdiff(fieldnames(stats__.(s))', 'num')
|
|
f = char(f) ;
|
|
stats.(s).(f) = stats.(s).(f) + stats__.(s).(f) * num__ ;
|
|
|
|
if g == numel(stats_)
|
|
stats.(s).(f) = stats.(s).(f) / total ;
|
|
end
|
|
end
|
|
end
|
|
stats.(s).num = total ;
|
|
end
|
|
|
|
% -------------------------------------------------------------------------
|
|
function stats = extractStats(net, params, errors)
|
|
% -------------------------------------------------------------------------
|
|
stats.objective = errors(1) ;
|
|
for i = 1:numel(params.errorLabels)
|
|
stats.(params.errorLabels{i}) = errors(i+1) ;
|
|
end
|
|
|
|
% -------------------------------------------------------------------------
|
|
function saveState(fileName, net, state)
|
|
% -------------------------------------------------------------------------
|
|
save(fileName, 'net', 'state') ;
|
|
|
|
% -------------------------------------------------------------------------
|
|
function saveStats(fileName, stats)
|
|
% -------------------------------------------------------------------------
|
|
if exist(fileName)
|
|
save(fileName, 'stats', '-append') ;
|
|
else
|
|
save(fileName, 'stats') ;
|
|
end
|
|
|
|
% -------------------------------------------------------------------------
|
|
function [net, state, stats] = loadState(fileName)
|
|
% -------------------------------------------------------------------------
|
|
load(fileName, 'net', 'state', 'stats') ;
|
|
net = vl_simplenn_tidy(net) ;
|
|
if isempty(whos('stats'))
|
|
error('Epoch ''%s'' was only partially saved. Delete this file and try again.', ...
|
|
fileName) ;
|
|
end
|
|
|
|
% -------------------------------------------------------------------------
|
|
function epoch = findLastCheckpoint(modelDir)
|
|
% -------------------------------------------------------------------------
|
|
list = dir(fullfile(modelDir, 'net-epoch-*.mat')) ;
|
|
tokens = regexp({list.name}, 'net-epoch-([\d]+).mat', 'tokens') ;
|
|
epoch = cellfun(@(x) sscanf(x{1}{1}, '%d'), tokens) ;
|
|
epoch = max([epoch 0]) ;
|
|
|
|
% -------------------------------------------------------------------------
|
|
function switchFigure(n)
|
|
% -------------------------------------------------------------------------
|
|
if get(0,'CurrentFigure') ~= n
|
|
try
|
|
set(0,'CurrentFigure',n) ;
|
|
catch
|
|
figure(n) ;
|
|
end
|
|
end
|
|
|
|
% -------------------------------------------------------------------------
|
|
function clearMex()
|
|
% -------------------------------------------------------------------------
|
|
%clear vl_tmove vl_imreadjpeg ;
|
|
disp('Clearing mex files') ;
|
|
clear mex ;
|
|
clear vl_tmove vl_imreadjpeg ;
|
|
|
|
% -------------------------------------------------------------------------
|
|
function prepareGPUs(params, cold)
|
|
% -------------------------------------------------------------------------
|
|
numGpus = numel(params.gpus) ;
|
|
if numGpus > 1
|
|
% check parallel pool integrity as it could have timed out
|
|
pool = gcp('nocreate') ;
|
|
if ~isempty(pool) && pool.NumWorkers ~= numGpus
|
|
delete(pool) ;
|
|
end
|
|
pool = gcp('nocreate') ;
|
|
if isempty(pool)
|
|
parpool('local', numGpus) ;
|
|
cold = true ;
|
|
end
|
|
end
|
|
if numGpus >= 1 && cold
|
|
fprintf('%s: resetting GPU\n', mfilename) ;
|
|
clearMex() ;
|
|
if numGpus == 1
|
|
disp(gpuDevice(params.gpus)) ;
|
|
else
|
|
spmd
|
|
clearMex() ;
|
|
disp(gpuDevice(params.gpus(labindex))) ;
|
|
end
|
|
end
|
|
end
|