sustaining_gazes/matlab_version/face_validation/DeepLearnToolbox/CNN/cnntrain.m
2016-04-28 15:40:36 -04:00

29 lines
961 B
Matlab

function net = cnntrain(net, x, y, opts)
%CNNTRAIN Train a CNN with shuffled mini-batch updates for a fixed number of epochs.
%   NET = CNNTRAIN(NET, X, Y, OPTS) runs OPTS.numepochs passes over the
%   training set. Each pass shuffles the samples, splits them into
%   floor(m / OPTS.batchsize) mini-batches and, for every batch, calls
%   CNNFF (forward), CNNBP (backward) and CNNAPPLYGRADS (update).
%
%   X    - inputs; samples are indexed along the 3rd dimension (m = size(X,3)).
%   Y    - targets; one column per sample in X.
%   OPTS - struct with fields batchsize and numepochs; OPTS is also passed
%          unchanged to CNNAPPLYGRADS.
%
%   Before each epoch the whole set is forwarded once and the RMSE over all
%   elements of (NET.o - Y) is printed. The reported value is therefore the
%   error *before* that epoch's updates.
%
%   When m is not a multiple of batchsize, the samples left over after
%   shuffling (m - numbatches*batchsize of them) are not used in that epoch.
%   A different random subset is skipped each epoch.
%
%   NET.rL is reset on entry and receives a smoothed history of NET.L
%   (presumably the mini-batch loss set by CNNBP; confirm in cnnbp.m):
%   rL(k+1) = 0.99*rL(k) + 0.01*L, seeded with the first L.
m = size(x, 3);
numbatches = floor(m / opts.batchsize);
% floor() always yields an integer, so a "not integer" check could never
% fire. Guard the case that really breaks training instead: a batch size
% larger than the data set gives zero batches, and the network would
% silently never be updated.
if numbatches < 1
    error('cnntrain:batchsize', ...
        'batchsize (%d) is larger than the number of samples (%d)', ...
        opts.batchsize, m);
end
net.rL = [];
for i = 1 : opts.numepochs
    net = cnnff(net, x);
    % Flatten both operands so the RMSE is a scalar even with several output
    % units; mean() on an nout-by-m matrix would otherwise return 1-by-m.
    error_curr = sqrt(mean((net.o(:) - y(:)).^2));
    disp(['epoch ' num2str(i) '/' num2str(opts.numepochs), ' RMSE-', num2str(error_curr)]);
    tic;
    kk = randperm(m);
    for l = 1 : numbatches
        % Sample indices of the l-th mini-batch in this epoch's shuffle.
        idx = kk((l - 1) * opts.batchsize + 1 : l * opts.batchsize);
        batch_x = x(:, :, idx);
        batch_y = y(:, idx);
        net = cnnff(net, batch_x);
        net = cnnbp(net, batch_y);
        net = cnnapplygrads(net, opts);
        % Seed the moving average with the first observed loss.
        if isempty(net.rL)
            net.rL(1) = net.L;
        end
        net.rL(end + 1) = 0.99 * net.rL(end) + 0.01 * net.L;
    end
    toc;
end
end