% Source listing metadata: Matlab, 29 lines, 961 B
function net = cnntrain(net, x, y, opts)
%CNNTRAIN Train a CNN with shuffled mini-batch updates.
%   NET = CNNTRAIN(NET, X, Y, OPTS) runs OPTS.numepochs epochs. In every
%   epoch the samples are shuffled with randperm and split into
%   floor(size(X, 3) / OPTS.batchsize) mini-batches; each mini-batch is
%   passed through cnnff -> cnnbp -> cnnapplygrads.
%
%   Inputs
%     net  - network struct threaded through cnnff, cnnbp and
%            cnnapplygrads (their field contract is defined elsewhere).
%     x    - training inputs, indexed as x(:, :, k) for sample k.
%     y    - training targets, indexed as y(:, k) for sample k; must be
%            size-compatible with net.o as produced by cnnff(net, x).
%     opts - options struct. Fields read here: batchsize (positive
%            integer, samples per mini-batch) and numepochs (number of
%            epochs). The whole struct is also forwarded to cnnapplygrads.
%
%   Output
%     net  - the updated network. net.rL is reset on entry and then holds
%            an exponentially smoothed trace of net.L
%            (0.99 * previous + 0.01 * current): one seed entry plus one
%            entry per mini-batch over all epochs. (net.L is presumably
%            the mini-batch loss left by cnnbp - confirm.)
%
%   Notes
%     * The RMSE printed at the start of an epoch is measured on the full
%       training set BEFORE that epoch's updates, so the result of the
%       last epoch is not printed.
%     * When size(x, 3) is not a multiple of opts.batchsize, the leftover
%       samples of that epoch's random permutation are skipped (a
%       different subset each epoch).
%     * Errors if opts.batchsize is not a positive integer, or if there
%       are fewer samples than one full mini-batch.

    m = size(x, 3);

    bs = opts.batchsize;
    if ~(isscalar(bs) && isnumeric(bs) && bs >= 1 && bs == fix(bs))
        error('cnntrain: opts.batchsize must be a positive integer');
    end

    % floor() always yields an integer, so a rem(numbatches, 1) test can
    % never fire. The case that really needs guarding is "no complete
    % batch", which would otherwise silently skip all training.
    numbatches = floor(m / bs);
    if numbatches < 1
        error('cnntrain: opts.batchsize (%d) exceeds the number of samples (%d)', bs, m);
    end

    net.rL = [];
    for i = 1 : opts.numepochs
        % Full-set forward pass used for the per-epoch RMSE report.
        net = cnnff(net, x);
        % Flatten before averaging: mean() on a matrix works column-wise
        % and would give one value per sample instead of a scalar RMSE.
        residual = net.o - y;
        error_curr = sqrt(mean(residual(:) .^ 2));
        disp(['epoch ' num2str(i) '/' num2str(opts.numepochs), ' RMSE-', num2str(error_curr)]);

        tic;
        kk = randperm(m);
        for l = 1 : numbatches
            % Sample indices of the l-th mini-batch in this epoch's shuffle.
            idx = kk((l - 1) * bs + 1 : l * bs);
            batch_x = x(:, :, idx);
            batch_y = y(:, idx);

            net = cnnff(net, batch_x);
            net = cnnbp(net, batch_y);
            net = cnnapplygrads(net, opts);

            % Seed the smoothed loss trace with the first observed loss.
            if isempty(net.rL)
                net.rL(1) = net.L;
            end
            net.rL(end + 1) = 0.99 * net.rL(end) + 0.01 * net.L;
        end
        toc;
    end
end