function [net, stats] = cnn_train_reg(net, imdb, getBatch, varargin)
%CNN_TRAIN_REG An example implementation of SGD for training CNNs
%    CNN_TRAIN_REG() is an example learner implementing stochastic
%    gradient descent with momentum to train a CNN. It can be used
%    with different datasets and tasks by providing a suitable
%    getBatch function.
%
%    The function automatically restarts after each training epoch by
%    checkpointing.
%
%    The function supports training on CPU or on one or more GPUs
%    (specify the list of GPU IDs in the `gpus` option).

% Copyright (C) 2014-16 Andrea Vedaldi.
% All rights reserved.
%
% This file is part of the VLFeat library and is made available under
% the terms of the BSD license (see the COPYING file).

% This is a modified version of CNN_TRAIN from MatConvNet, adapted for
% regression.

addpath(fullfile(vl_rootnn, 'examples')) ;

opts.expDir = fullfile('data','exp') ;
opts.continue = true ;
opts.batchSize = 256 ;
opts.numSubBatches = 1 ;
opts.train = [] ;
opts.val = [] ;
opts.gpus = [] ;
opts.epochSize = inf ;
opts.prefetch = false ;
opts.numEpochs = 300 ;
opts.learningRate = 0.001 ;
opts.weightDecay = 0.0005 ;

opts.solver = [] ;  % empty array means use the default SGD solver
[opts, varargin] = vl_argparse(opts, varargin) ;
if ~isempty(opts.solver)
  assert(isa(opts.solver, 'function_handle') && nargout(opts.solver) == 2, ...
    'Invalid solver; expected a function handle with two outputs.') ;
  % Call without input arguments, to get the default solver options.
  opts.solverOpts = opts.solver() ;
end

opts.momentum = 0.9 ;
opts.saveSolverState = true ;
opts.nesterovUpdate = false ;
opts.randomSeed = 0 ;
opts.memoryMapFile = fullfile(tempdir, 'matconvnet.bin') ;
opts.profile = false ;
opts.parameterServer.method = 'mmap' ;
opts.parameterServer.prefix = 'mcn' ;

opts.conserveMemory = true ;
opts.backPropDepth = +inf ;
opts.sync = false ;
opts.cudnn = true ;
opts.errorFunction = 'regression' ;
opts.errorLabels = {} ;
opts.plotDiagnostics = false ;
opts.plotStatistics = true ;
opts.postEpochFn = [] ;  % postEpochFn(net,params,state) is called after each
                         % epoch; it can return a new learning rate, 0 to
                         % stop, or [] for no change
opts = vl_argparse(opts, varargin) ;
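% Example (illustrative sketch, not part of the original file): a postEpochFn
% hook that halves the learning rate every 10 epochs. Per the contract above,
% returning [] leaves the rate unchanged and returning 0 stops training; the
% base rate 0.001 is a placeholder.
%
%   opts.postEpochFn = @(net, params, state) ...
%     0.001 * 0.5^floor(params.epoch / 10) ;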
if ~exist(opts.expDir, 'dir'), mkdir(opts.expDir) ; end
if isempty(opts.train), opts.train = find(imdb.images.set==1) ; end
if isempty(opts.val), opts.val = find(imdb.images.set==2) ; end
if isscalar(opts.train) && isnumeric(opts.train) && isnan(opts.train)
  opts.train = [] ;
end
if isscalar(opts.val) && isnumeric(opts.val) && isnan(opts.val)
  opts.val = [] ;
end

hasError = true ;
opts.errorLabels = {'correlation', 'rmse'} ;

% -------------------------------------------------------------------------
%                                                            Initialization
% -------------------------------------------------------------------------

net = vl_simplenn_tidy(net) ; % fill in any missing default values
net.layers{end-1}.precious = 1 ; % do not remove predictions, used for error
vl_simplenn_display(net, 'batchSize', opts.batchSize) ;

evaluateMode = isempty(opts.train) ;
if ~evaluateMode
  for i = 1:numel(net.layers)
    J = numel(net.layers{i}.weights) ;
    if ~isfield(net.layers{i}, 'learningRate')
      net.layers{i}.learningRate = ones(1, J) ;
    end
    if ~isfield(net.layers{i}, 'weightDecay')
      net.layers{i}.weightDecay = ones(1, J) ;
    end
  end
end

state.getBatch = getBatch ;
stats = [] ;

% -------------------------------------------------------------------------
%                                                        Train and validate
% -------------------------------------------------------------------------

modelPath = @(ep) fullfile(opts.expDir, sprintf('net-epoch-%d.mat', ep)) ;
modelFigPath = fullfile(opts.expDir, 'net-train.pdf') ;

start = opts.continue * findLastCheckpoint(opts.expDir) ;
if start >= 1
  fprintf('%s: resuming by loading epoch %d\n', mfilename, start) ;
  [net, state, stats] = loadState(modelPath(start)) ;
else
  state = [] ;
end

for epoch = start+1:opts.numEpochs

  % Set the random seed based on the epoch and opts.randomSeed.
  % This is important for reproducibility, including when training
  % is restarted from a checkpoint.
  rng(epoch + opts.randomSeed) ;
  prepareGPUs(opts, epoch == start+1) ;

  % Train for one epoch.
  params = opts ;
  params.epoch = epoch ;
  params.learningRate = opts.learningRate(min(epoch, numel(opts.learningRate))) ;
  params.train = opts.train(randperm(numel(opts.train))) ; % shuffle
  params.train = params.train(1:min(opts.epochSize, numel(opts.train))) ;
  params.imdb = imdb ;
  params.getBatch = getBatch ;

  if numel(params.gpus) <= 1
    [net, state] = processEpoch(net, state, params, 'train') ;
    [net, state] = processEpoch(net, state, params, 'val') ;
    if ~evaluateMode
      saveState(modelPath(epoch), net, state) ;
    end
    lastStats = state.stats ;
  else
    spmd
      [net, state] = processEpoch(net, state, params, 'train') ;
      [net, state] = processEpoch(net, state, params, 'val') ;
      if labindex == 1 && ~evaluateMode
        saveState(modelPath(epoch), net, state) ;
      end
      lastStats = state.stats ;
    end
    lastStats = accumulateStats(lastStats) ;
  end

  stats.train(epoch) = lastStats.train ;
  stats.val(epoch) = lastStats.val ;
  clear lastStats ;
  if ~evaluateMode
    saveStats(modelPath(epoch), stats) ;
  end

  if params.plotStatistics
    switchFigure(1) ; clf ;
    plots = setdiff(...
      cat(2, ...
        fieldnames(stats.train)', ...
        fieldnames(stats.val)'), {'num', 'time'}) ;
    for p = plots
      p = char(p) ;
      values = zeros(0, epoch) ;
      leg = {} ;
      for f = {'train', 'val'}
        f = char(f) ;
        if isfield(stats.(f), p)
          tmp = [stats.(f).(p)] ;
          values(end+1,:) = tmp(1,:)' ;
          leg{end+1} = f ;
        end
      end
      subplot(1, numel(plots), find(strcmp(p, plots))) ;
      plot(1:epoch, values', 'o-') ;
      xlabel('epoch') ;
      title(p) ;
      legend(leg{:}) ;
      grid on ;
    end
    drawnow ;
    print(1, modelFigPath, '-dpdf') ;
  end

  if ~isempty(opts.postEpochFn)
    if nargout(opts.postEpochFn) == 0
      opts.postEpochFn(net, params, state) ;
    else
      lr = opts.postEpochFn(net, params, state) ;
      if ~isempty(lr), opts.learningRate = lr ; end
      if opts.learningRate == 0, break ; end
    end
  end
end

% Return the best performing model.
[~, best_epoch] = min(cat(1, stats.val.rmse)) ;
fprintf('%s: Best model in epoch %d\n', mfilename, best_epoch) ;
[net, state, stats] = loadState(modelPath(best_epoch)) ;

% With multiple GPUs, return one copy.
if isa(net, 'Composite'), net = net{1} ; end
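% The trainer only assumes that getBatch(imdb, batch) returns a batch of
% images and their targets. Below is a minimal sketch of such a function;
% the field names imdb.images.data and imdb.images.labels are assumptions
% for illustration, not part of this file's contract.
% -------------------------------------------------------------------------
function [im, labels] = getBatchExample(imdb, batch)
% -------------------------------------------------------------------------
im = imdb.images.data(:,:,:,batch) ;
labels = imdb.images.labels(1,batch) ;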
% -------------------------------------------------------------------------
function [net, state] = processEpoch(net, state, params, mode)
% -------------------------------------------------------------------------
% Note that net is not strictly needed as an output argument as net
% is a handle class. However, this fixes some aliasing issue in the
% spmd caller.

% initialize with momentum 0
if isempty(state) || isempty(state.solverState)
  for i = 1:numel(net.layers)
    state.solverState{i} = cell(1, numel(net.layers{i}.weights)) ;
    state.solverState{i}(:) = {0} ;
  end
end

% move CNN to GPU as needed
numGpus = numel(params.gpus) ;
if numGpus >= 1
  net = vl_simplenn_move(net, 'gpu') ;
  for i = 1:numel(state.solverState)
    for j = 1:numel(state.solverState{i})
      s = state.solverState{i}{j} ;
      if isnumeric(s)
        state.solverState{i}{j} = gpuArray(s) ;
      elseif isstruct(s)
        state.solverState{i}{j} = structfun(@gpuArray, s, 'UniformOutput', false) ;
      end
    end
  end
end
if numGpus > 1
  parserv = ParameterServer(params.parameterServer) ;
  vl_simplenn_start_parserv(net, parserv) ;
else
  parserv = [] ;
end

% profile
if params.profile
  if numGpus <= 1
    profile clear ;
    profile on ;
  else
    mpiprofile reset ;
    mpiprofile on ;
  end
end

subset = params.(mode) ;
num = 0 ;
stats.num = 0 ; % return something even if subset = []
stats.time = 0 ;
adjustTime = 0 ;
res = [] ;
preds_all = [] ;
labels_all = [] ;
err = 0 ;

start = tic ;
for t = 1:params.batchSize:numel(subset)
  fprintf('%s: epoch %02d: %3d/%3d:', mode, params.epoch, ...
    fix((t-1)/params.batchSize)+1, ceil(numel(subset)/params.batchSize)) ;
  batchSize = min(params.batchSize, numel(subset) - t + 1) ;

  for s = 1:params.numSubBatches
    % get this image batch and prefetch the next
    batchStart = t + (labindex-1) + (s-1) * numlabs ;
    batchEnd = min(t+params.batchSize-1, numel(subset)) ;
    batch = subset(batchStart : params.numSubBatches * numlabs : batchEnd) ;
    num = num + numel(batch) ;
    if numel(batch) == 0, continue ; end

    [im, labels] = params.getBatch(params.imdb, batch) ;
    if params.prefetch
      if s == params.numSubBatches
        batchStart = t + (labindex-1) + params.batchSize ;
        batchEnd = min(t+2*params.batchSize-1, numel(subset)) ;
      else
        batchStart = batchStart + numlabs ;
      end
      nextBatch = subset(batchStart : params.numSubBatches * numlabs : batchEnd) ;
      params.getBatch(params.imdb, nextBatch) ;
    end

    if numGpus >= 1
      im = gpuArray(im) ;
    end

    if strcmp(mode, 'train')
      dzdy = 1 ;
      evalMode = 'normal' ;
    else
      dzdy = [] ;
      evalMode = 'test' ;
    end
    net.layers{end}.class = labels ;
    res = vl_simplenn(net, im, dzdy, res, ...
      'accumulate', s ~= 1, ...
      'mode', evalMode, ...
      'conserveMemory', params.conserveMemory, ...
      'backPropDepth', params.backPropDepth, ...
      'sync', params.sync, ...
      'cudnn', params.cudnn, ...
      'parameterServer', parserv, ...
      'holdOn', s < params.numSubBatches) ;

    % Take the top-scoring class bin as the prediction.
    predictions = gather(res(end-1).x) ;
    [~, predictions] = sort(predictions, 3, 'descend') ;
    predictions = squeeze(predictions) ;
    num_bins = size(predictions, 1) ;
    predictions = predictions(1,:) ;

    % Convert the class labels into the continuous values.
    labels = unQuantizeContinuous(squeeze(labels), 0, 3, num_bins)' ;
    predictions = unQuantizeContinuous(squeeze(predictions), 0, 3, num_bins)' ;
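    % Note: unQuantizeContinuous is an external helper, defined neither in
    % this file nor in MatConvNet. Presumably it maps a 1-based bin index
    % back to a continuous value on [lo, hi] = [0, 3]; under that
    % assumption, a bin-centre mapping would be roughly:
    %   value = lo + (bin - 0.5) * (hi - lo) / num_bins ;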
    preds_all = cat(1, preds_all, predictions) ;
    labels_all = cat(1, labels_all, labels) ;

    % Accumulate the objective and update the running regression errors.
    err = [err(1)+sum(double(gather(res(end).x))); ...
           corr(labels_all, preds_all); ...
           sqrt(mean((labels_all-preds_all).^2))] ;
  end

  % accumulate gradient
  if strcmp(mode, 'train')
    if ~isempty(parserv), parserv.sync() ; end
    [net, res, state] = accumulateGradients(net, res, state, params, batchSize, parserv) ;
  end

  % get statistics
  time = toc(start) + adjustTime ;
  batchTime = time - stats.time ;
  stats = extractStats(net, params, [err(1)/num; err(2); err(3)]) ;
  stats.num = num ;
  stats.time = time ;
  currentSpeed = batchSize / batchTime ;
  averageSpeed = (t + batchSize - 1) / time ;
  if t == 3*params.batchSize + 1
    % compensate for the first three iterations, which are outliers
    adjustTime = 4*batchTime - time ;
    stats.time = time + adjustTime ;
  end

  fprintf(' %.1f (%.1f) Hz', averageSpeed, currentSpeed) ;
  for f = setdiff(fieldnames(stats)', {'num', 'time'})
    f = char(f) ;
    fprintf(' %s: %.3f', f, stats.(f)) ;
  end
  fprintf('\n') ;

  % collect diagnostic statistics
  if strcmp(mode, 'train') && params.plotDiagnostics
    switchFigure(2) ; clf ;
    diagn = [res.stats] ;
    diagnvar = horzcat(diagn.variation) ;
    diagnpow = horzcat(diagn.power) ;
    subplot(2,2,1) ; barh(diagnvar) ;
    set(gca, 'TickLabelInterpreter', 'none', ...
      'YTick', 1:numel(diagnvar), ...
      'YTickLabel', horzcat(diagn.label), ...
      'YDir', 'reverse', ...
      'XScale', 'log', ...
      'XLim', [1e-5 1], ...
      'XTick', 10.^(-5:1)) ;
    grid on ; title('Variation') ;
    subplot(2,2,2) ; barh(sqrt(diagnpow)) ;
    set(gca, 'TickLabelInterpreter', 'none', ...
      'YTick', 1:numel(diagnpow), ...
      'YTickLabel', {diagn.powerLabel}, ...
      'YDir', 'reverse', ...
      'XScale', 'log', ...
      'XLim', [1e-5 1e5], ...
      'XTick', 10.^(-5:5)) ;
    grid on ; title('Power') ;
    subplot(2,2,3) ; plot(squeeze(res(end-1).x)) ;
    drawnow ;
  end
end

% Save back to state.
state.stats.(mode) = stats ;
if params.profile
  if numGpus <= 1
    state.prof.(mode) = profile('info') ;
    profile off ;
  else
    state.prof.(mode) = mpiprofile('info') ;
    mpiprofile off ;
  end
end
if ~params.saveSolverState
  state.solverState = [] ;
else
  for i = 1:numel(state.solverState)
    for j = 1:numel(state.solverState{i})
      s = state.solverState{i}{j} ;
      if isnumeric(s)
        state.solverState{i}{j} = gather(s) ;
      elseif isstruct(s)
        state.solverState{i}{j} = structfun(@gather, s, 'UniformOutput', false) ;
      end
    end
  end
end

net = vl_simplenn_move(net, 'cpu') ;

% -------------------------------------------------------------------------
function [net, res, state] = accumulateGradients(net, res, state, params, batchSize, parserv)
% -------------------------------------------------------------------------
numGpus = numel(params.gpus) ;
otherGpus = setdiff(1:numGpus, labindex) ;

for l = numel(net.layers):-1:1
  for j = numel(res(l).dzdw):-1:1

    if ~isempty(parserv)
      tag = sprintf('l%d_%d', l, j) ;
      parDer = parserv.pull(tag) ;
    else
      parDer = res(l).dzdw{j} ;
    end

    if j == 3 && strcmp(net.layers{l}.type, 'bnorm')
      % special case for learning bnorm moments
      thisLR = net.layers{l}.learningRate(j) ;
      net.layers{l}.weights{j} = vl_taccum(...
        1 - thisLR, ...
        net.layers{l}.weights{j}, ...
        thisLR / batchSize, ...
        parDer) ;
    else
      % Standard gradient training.
      thisDecay = params.weightDecay * net.layers{l}.weightDecay(j) ;
      thisLR = params.learningRate * net.layers{l}.learningRate(j) ;

      if thisLR > 0 || thisDecay > 0
        % Normalize gradient and incorporate weight decay.
        parDer = vl_taccum(1/batchSize, parDer, ...
          thisDecay, net.layers{l}.weights{j}) ;

        if isempty(params.solver)
          % Default solver is the optimised SGD.
          % Update momentum.
          state.solverState{l}{j} = vl_taccum(...
            params.momentum, state.solverState{l}{j}, ...
            -1, parDer) ;

          % Nesterov update (aka one step ahead).
          if params.nesterovUpdate
            delta = params.momentum * state.solverState{l}{j} - parDer ;
          else
            delta = state.solverState{l}{j} ;
          end

          % Update parameters.
          net.layers{l}.weights{j} = vl_taccum(...
            1, net.layers{l}.weights{j}, ...
            thisLR, delta) ;
        else
          % call solver function to update weights
          [net.layers{l}.weights{j}, state.solverState{l}{j}] = ...
            params.solver(net.layers{l}.weights{j}, state.solverState{l}{j}, ...
            parDer, params.solverOpts, thisLR) ;
        end
      end
    end

    % if requested, collect some useful stats for debugging
    if params.plotDiagnostics
      variation = [] ;
      label = '' ;
      switch net.layers{l}.type
        case {'conv', 'convt'}
          if isnumeric(state.solverState{l}{j})
            variation = thisLR * mean(abs(state.solverState{l}{j}(:))) ;
          end
          power = mean(res(l+1).x(:).^2) ;
          if j == 1 % filters
            base = mean(net.layers{l}.weights{j}(:).^2) ;
            label = 'filters' ;
          else % biases
            base = sqrt(power) ; % mean(abs(res(l+1).x(:))) ;
            label = 'biases' ;
          end
          variation = variation / base ;
          label = sprintf('%s_%s', net.layers{l}.name, label) ;
      end
      res(l).stats.variation(j) = variation ;
      res(l).stats.power = power ;
      res(l).stats.powerLabel = net.layers{l}.name ;
      res(l).stats.label{j} = label ;
    end
  end
end
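% An illustrative sketch (not part of the original file) of a solver that
% satisfies the opts.solver contract documented above: called with no inputs
% it returns its default options struct; otherwise it performs a plain
% heavy-ball momentum step, mirroring the built-in SGD path.
% -------------------------------------------------------------------------
function [w, state] = exampleMomentumSolver(w, state, grad, solverOpts, lr)
% -------------------------------------------------------------------------
if nargin == 0
  % Called as opts.solver() to obtain the default solver options.
  w = struct('momentum', 0.9) ;
  return ;
end
% state starts out as the scalar 0 and broadcasts to the weight shape.
state = solverOpts.momentum * state - grad ;
w = w + lr * state ;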
% -------------------------------------------------------------------------
function stats = accumulateStats(stats_)
% -------------------------------------------------------------------------
for s = {'train', 'val'}
  s = char(s) ;
  total = 0 ;

  % initialize stats structure with same fields and same order as
  % stats_{1}
  stats__ = stats_{1} ;
  names = fieldnames(stats__.(s))' ;
  values = zeros(1, numel(names)) ;
  fields = cat(1, names, num2cell(values)) ;
  stats.(s) = struct(fields{:}) ;

  for g = 1:numel(stats_)
    stats__ = stats_{g} ;
    num__ = stats__.(s).num ;
    total = total + num__ ;

    for f = setdiff(fieldnames(stats__.(s))', 'num')
      f = char(f) ;
      stats.(s).(f) = stats.(s).(f) + stats__.(s).(f) * num__ ;

      if g == numel(stats_)
        stats.(s).(f) = stats.(s).(f) / total ;
      end
    end
  end
  stats.(s).num = total ;
end

% -------------------------------------------------------------------------
function stats = extractStats(net, params, errors)
% -------------------------------------------------------------------------
stats.objective = errors(1) ;
for i = 1:numel(params.errorLabels)
  stats.(params.errorLabels{i}) = errors(i+1) ;
end

% -------------------------------------------------------------------------
function saveState(fileName, net, state)
% -------------------------------------------------------------------------
save(fileName, 'net', 'state') ;

% -------------------------------------------------------------------------
function saveStats(fileName, stats)
% -------------------------------------------------------------------------
if exist(fileName, 'file')
  save(fileName, 'stats', '-append') ;
else
  save(fileName, 'stats') ;
end

% -------------------------------------------------------------------------
function [net, state, stats] = loadState(fileName)
% -------------------------------------------------------------------------
load(fileName, 'net', 'state', 'stats') ;
net = vl_simplenn_tidy(net) ;
if isempty(whos('stats'))
  error('Epoch ''%s'' was only partially saved. Delete this file and try again.', ...
    fileName) ;
end
% -------------------------------------------------------------------------
function epoch = findLastCheckpoint(modelDir)
% -------------------------------------------------------------------------
list = dir(fullfile(modelDir, 'net-epoch-*.mat')) ;
tokens = regexp({list.name}, 'net-epoch-([\d]+).mat', 'tokens') ;
epoch = cellfun(@(x) sscanf(x{1}{1}, '%d'), tokens) ;
epoch = max([epoch 0]) ;

% -------------------------------------------------------------------------
function switchFigure(n)
% -------------------------------------------------------------------------
if get(0, 'CurrentFigure') ~= n
  try
    set(0, 'CurrentFigure', n) ;
  catch
    figure(n) ;
  end
end

% -------------------------------------------------------------------------
function clearMex()
% -------------------------------------------------------------------------
disp('Clearing mex files') ;
clear mex ;
clear vl_tmove vl_imreadjpeg ;

% -------------------------------------------------------------------------
function prepareGPUs(params, cold)
% -------------------------------------------------------------------------
numGpus = numel(params.gpus) ;
if numGpus > 1
  % check parallel pool integrity as it could have timed out
  pool = gcp('nocreate') ;
  if ~isempty(pool) && pool.NumWorkers ~= numGpus
    delete(pool) ;
  end
  pool = gcp('nocreate') ;
  if isempty(pool)
    parpool('local', numGpus) ;
    cold = true ;
  end
end
if numGpus >= 1 && cold
  fprintf('%s: resetting GPU\n', mfilename) ;
  clearMex() ;
  if numGpus == 1
    disp(gpuDevice(params.gpus)) ;
  else
    spmd
      clearMex() ;
      disp(gpuDevice(params.gpus(labindex))) ;
    end
  end
end
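% Example (illustrative): a typical invocation of this trainer, assuming an
% imdb structure and a batch loader such as the getBatchExample sketch above;
% the option values shown here are placeholders, not recommendations.
%
%   [net, stats] = cnn_train_reg(net, imdb, @getBatchExample, ...
%     'expDir', fullfile('data', 'reg-exp'), ...
%     'batchSize', 128, ...
%     'learningRate', logspace(-2, -4, 60), ...
%     'gpus', []) ;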