This post is a line-by-line analysis of the code from the previous article.
1. Main function
function [net, info] = cnn_dicnn(varargin)
% Run vl_setupnn first, adding all MatConvNet subfolders to the MATLAB path.
% mfilename('fullpath') returns the full path of the current file, including its name.
% [pathstr,name,ext] = fileparts(filename) splits a full path into directory, name, and extension.
run(fullfile(fileparts(mfilename('fullpath')),'matconvnet', 'matlab', 'vl_setupnn.m')) ;
% Paths to the image data and the experiment (output) folder
opts.dataDir = fullfile('data','image') ;
opts.expDir = fullfile('exp', 'image') ;
% Path to the pre-trained model
opts.modelPath = fullfile('models','imagenet-vgg-f.mat');
% Merge the caller's name-value pairs into the opts struct
[opts, varargin] = vl_argparse(opts, varargin) ;
opts.numFetchThreads = 12 ; % number of image-fetching threads
opts.lite = false ; % if true, build only a small "lite" subset of the data
opts.imdbPath = fullfile(opts.expDir, 'imdb.mat'); % where the imdb file is stored
% Training options go into their own sub-struct
opts.train = struct() ;
opts.train.gpus = [1]; % GPU indices to use (empty = CPU)
opts.train.batchSize = 8 ; % mini-batch size
opts.train.numSubBatches = 4 ;
opts.train.learningRate = 1e-4 * [ones(1,10), 0.1*ones(1,5)]; % schedule: 1e-4 for 10 epochs, then 1e-5 for 5
opts = vl_argparse(opts, varargin) ;
if ~isfield(opts.train, 'gpus'), opts.train.gpus = []; end
% -------------------------------------------------------------------------
% Prepare model
% -------------------------------------------------------------------------
net = load(opts.modelPath);
% Adapt the model to our task; see function 2 (prepareDINet) below
net = prepareDINet(net,opts);
% -------------------------------------------------------------------------
% Prepare data
% -------------------------------------------------------------------------
% Prepare the data in imdb format
if exist(opts.imdbPath,'file')
imdb = load(opts.imdbPath) ;
else
imdb = cnn_image_setup_data('dataDir', opts.dataDir, 'lite', opts.lite) ; % see function 3 below
mkdir(opts.expDir) ; % create the exp/image folder
save(opts.imdbPath, '-struct', 'imdb') ; % save the struct fields; see https://ww2.mathworks.cn/help/matlab/ref/save.html
end
imdb.images.set = imdb.images.sets; % alias to the field name cnn_train_dag expects
% Replace the original class descriptions (1000 ImageNet classes) with our own (10 classes)
net.meta.classes.name = imdb.classes.name ;
net.meta.classes.description = imdb.classes.name ;
% Compute the mean image of the training set; see function 4 below
imageStatsPath = fullfile(opts.expDir, 'imageStats.mat') ;
if exist(imageStatsPath, 'file')
load(imageStatsPath, 'averageImage') ;
else
averageImage = getImageStats(opts, net.meta, imdb) ;
save(imageStatsPath, 'averageImage') ;
end
% Overwrite the stored average image with the newly computed one
net.meta.normalization.averageImage = averageImage;
% -------------------------------------------------------------------------
% Learn
% -------------------------------------------------------------------------
% Indices of the training set (set==1) and the test set (set==3)
opts.train.train = find(imdb.images.set==1) ;
opts.train.val = find(imdb.images.set==3) ;
% Train
[net, info] = cnn_train_dag(net, imdb, getBatchFn(opts, net.meta), ...
'expDir', opts.expDir, ...
opts.train) ;
% -------------------------------------------------------------------------
% Deploy
% -------------------------------------------------------------------------
% Deploy and save the trained network
net = cnn_imagenet_deploy(net);
modelPath = fullfile(opts.expDir, 'net-deployed.mat');
net_ = net.saveobj() ;
save(modelPath, '-struct', 'net_') ;
clear net_ ;
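Since vl_argparse merges any name-value pairs passed by the caller into opts, the whole pipeline can be launched with overrides. A hypothetical call (paths depend on your own layout):

[net, info] = cnn_dicnn('dataDir', fullfile('data','image'), ...
                        'expDir', fullfile('exp','image'), ...
                        'modelPath', fullfile('models','imagenet-vgg-f.mat')) ;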
2. Model preparation
% -------------------------------------------------------------------------
function net = prepareDINet(net,opts)
% -------------------------------------------------------------------------
% Locate the fc8 layer; fc8l is a logical index into net.layers
% cellfun applies a function to every cell of a cell array; docs: https://ww2.mathworks.cn/help/matlab/ref/cellfun.html
fc8l = cellfun(@(a) strcmp(a.name, 'fc8'), net.layers)==1;
%% NOTE: nCls is the number of classes and must match your own dataset (10 here)
nCls = 10;
sizeW = size(net.layers{fc8l}.weights{1});
% If the required number of classes differs from the original network's, initialize the weights with zeros
if sizeW(4)~=nCls
net.layers{fc8l}.weights = {zeros(sizeW(1),sizeW(2),sizeW(3),nCls,'single'), ...
zeros(1, nCls, 'single')};
end
% Change the loss: replace the final layer with a softmaxloss layer for training
net.layers{end} = struct('name','loss', 'type','softmaxloss') ;
% Convert the SimpleNN into the more flexible DagNN format; reference: https://www.cnblogs.com/ironstark/p/6058090.html
net = dagnn.DagNN.fromSimpleNN(net, 'canonicalNames', true) ;
% Add top-1 and top-5 error layers
net.addLayer('top1err', dagnn.Loss('loss', 'classerror'), ...
{'prediction','label'}, 'top1err') ;
net.addLayer('top5err', dagnn.Loss('loss', 'topkerror', ...
'opts', {'topK',5}), ...
{'prediction','label'}, 'top5err') ;
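The fc8 lookup above relies on cellfun returning a logical vector that can index the layer cell array directly. A minimal sketch of the same pattern on a made-up layer list:

layers = {struct('name','fc7'), struct('name','fc8'), struct('name','prob')} ;
idx = cellfun(@(a) strcmp(a.name, 'fc8'), layers) ; % logical [0 1 0]
disp(layers{idx}.name) ; % prints 'fc8'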
3. Dataset preparation
function imdb = cnn_image_setup_data(varargin)
opts.dataDir = fullfile('data','image') ;
opts.lite = false ;
opts = vl_argparse(opts, varargin) ;
% ------------------------------------------------------------------------
% Load categories metadata
% -------------------------------------------------------------------------
metaPath = fullfile(opts.dataDir, 'classInd.txt') ;
fprintf('using metadata %s\n', metaPath) ;
tmp = importdata(metaPath);
nCls = numel(tmp);
% Check that the number of classes matches the expected total (10 here; change for your own dataset)
if nCls ~= 10
error('Wrong meta file %s',metaPath);
end
% Extract the class names
cats = cell(1,nCls);
for i=1:numel(tmp)
t = strsplit(tmp{i});
cats{i} = t{2};
end
% Dataset folders
imdb.classes.name = cats ; % class names
imdb.imageDir.train = fullfile(opts.dataDir, 'train') ; % training images folder
imdb.imageDir.test = fullfile(opts.dataDir, 'test') ; % test images folder
%% -----------------------------------------------------------------
% load image names and labels
% -------------------------------------------------------------------------
name = {};
labels = {} ;
imdb.images.sets = [] ;
%%
fprintf('searching training images ...\n') ;
% Load the training labels
train_label_path = fullfile(opts.dataDir, 'train_label.txt') ;
train_label_temp = importdata(train_label_path);
temp_l = train_label_temp.data;
for i=1:numel(temp_l)
train_label{i} = temp_l(i);
end
if length(train_label) ~= length(dir(fullfile(imdb.imageDir.train, '*.jpg')))
error('number of training images does not match the number of labels!');
end
i = 1;
for d = dir(fullfile(imdb.imageDir.train, '*.jpg'))'
name{end+1} = d.name;
labels{end+1} = train_label{i} ;
if mod(numel(name), 10) == 0, fprintf('.') ; end
if mod(numel(name), 500) == 0, fprintf('\n') ; end
imdb.images.sets(end+1) = 1;%train
i = i+1;
end
%%
fprintf('searching testing images ...\n') ;
% Load the test labels
test_label_path = fullfile(opts.dataDir, 'test_label.txt') ;
test_label_temp = importdata(test_label_path);
temp_l = test_label_temp.data;
for i=1:numel(temp_l)
test_label{i} = temp_l(i);
end
if length(test_label) ~= length(dir(fullfile(imdb.imageDir.test, '*.jpg')))
error('number of test images does not match the number of labels!');
end
i = 1;
for d = dir(fullfile(imdb.imageDir.test, '*.jpg'))'
name{end+1} = d.name;
labels{end+1} = test_label{i} ;
if mod(numel(name), 10) == 0, fprintf('.') ; end
if mod(numel(name), 500) == 0, fprintf('\n') ; end
imdb.images.sets(end+1) = 3;%test
i = i+1;
end
labels = horzcat(labels{:}) ; % horzcat concatenates arrays horizontally
imdb.images.id = 1:numel(name) ; % image IDs
imdb.images.name = name ; % image file names
imdb.images.label = labels ; % image labels
The resulting imdb structure is summarized below (the original post showed it as a screenshot).
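Reconstructed from the code above (N is the total number of images):

% imdb.classes.name     1 x nCls cell   class names
% imdb.imageDir.train   char            training images folder
% imdb.imageDir.test    char            test images folder
% imdb.images.id        1 x N double    image IDs (1..N)
% imdb.images.name      1 x N cell      file names
% imdb.images.label     1 x N double    class labels
% imdb.images.sets      1 x N double    1 = train, 3 = test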
4. Computing the sample mean
% Compute the mean image of the training samples
% -------------------------------------------------------------------------
function averageImage = getImageStats(opts, meta, imdb)
% -------------------------------------------------------------------------
train = find(imdb.images.set == 1) ; % indices flagged 1 (the training set)
batch = 1:length(train);
fn = getBatchFn(opts, meta) ; % see function 5 below
train = train(1: 100: end); % every 100th index is a batch start (batches of 100 images)
avg = {};
for i = 1:length(train)
temp = fn(imdb, batch(train(i):train(i)+99)) ; % temp is the {'input', im, 'label', labels} cell (assumes the image count is a multiple of 100)
temp = temp{2}; % keep only the images, 224x224x3x100
avg{end+1} = mean(temp, 4) ; % average over the 4th dimension: one mean image per batch
end
averageImage = mean(cat(4,avg{:}),4) ; % then average the per-batch means
% gather moves GPU arrays back to CPU memory before saving (a no-op on CPU)
averageImage = gather(averageImage);
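The statistic itself is just a mean over the 4th (batch) dimension; a self-contained sketch on random data:

ims = rand(224, 224, 3, 100, 'single') ; % one batch of 100 images
batchMean = mean(ims, 4) ;               % 224x224x3 mean of this batch
batchMeans = cat(4, batchMean, batchMean) ; % stack the per-batch means
averageImage = mean(batchMeans, 4) ;     % overall mean image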
5. Defining the batch function (getBatchFn)
function fn = getBatchFn(opts, meta)
% -------------------------------------------------------------------------
useGpu = numel(opts.train.gpus) > 0 ; % whether to use the GPU
bopts.numThreads = opts.numFetchThreads ; %12
bopts.imageSize = meta.normalization.imageSize ; %[224,224,3,10]
bopts.border = meta.normalization.border ; %[32,32]
% bopts.averageImage = [];
bopts.averageImage = meta.normalization.averageImage ; %224x224x3 double
% bopts.rgbVariance = meta.augmentation.rgbVariance ;
% bopts.transformation = meta.augmentation.transformation ;
fn = @(x,y) getDagNNBatch(bopts,useGpu,x,y) ; % anonymous function: bopts and useGpu are captured in the closure
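MATLAB anonymous functions capture workspace variables by value at creation time, which is why bopts and useGpu never need to be passed around later. A tiny sketch of the same closure behaviour:

offset = 10 ;
f = @(x) x + offset ; % offset is captured here, by value
offset = 99 ;         % later changes do not affect f
disp(f(1)) ;          % prints 11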
6. getDagNNBatch: inputs = {'input', im, 'label', labels}
Builds the input cell for the DagNN: im is 224x224x3x100 and labels is 1x100, packing 100 images and their labels per batch.
function inputs = getDagNNBatch(opts, useGpu, imdb, batch)
% -------------------------------------------------------------------------
% Decide whether each requested image comes from the train or the test folder
for i = 1:length(batch)
if imdb.images.set(batch(i)) == 1 % set==1 marks training images
images(i) = strcat([imdb.imageDir.train filesep] , imdb.images.name(batch(i))); % filesep is the path separator; strcat concatenates strings
else
images(i) = strcat([imdb.imageDir.test filesep] , imdb.images.name(batch(i)));
end
end
isVal = ~isempty(batch) && imdb.images.set(batch(1)) ~= 1 ;
if ~isVal
% training
im = cnn_imagenet_get_batch(images, opts, ...
'prefetch', nargout == 0) ; %步入函數(shù)7
else
% validation: disable data augmentation
im = cnn_imagenet_get_batch(images, opts, ...
'prefetch', nargout == 0, ...
'transformation', 'none') ;
end
if nargout > 0
if useGpu
im = gpuArray(im) ;
end
labels = imdb.images.label(batch) ;
inputs = {'input', im, 'label', labels} ;
end
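The returned cell is the variable-name/value list that dagnn.DagNN.eval expects; in cnn_train_dag below it is consumed as net.eval(inputs, params.derOutputs) during training (forward plus backward pass) and as net.eval(inputs) during validation (forward only).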
7. Loading and preprocessing a batch of images (resize to the network input size and subtract the mean)
function imo = cnn_imagenet_get_batch(images, varargin)
% CNN_IMAGENET_GET_BATCH Load, preprocess, and pack images for CNN evaluation
opts.imageSize = [227, 227] ;
opts.border = [29, 29] ;
opts.keepAspect = true ;
opts.numAugments = 1 ;
opts.transformation = 'none' ;
opts.averageImage = [] ;
opts.rgbVariance = zeros(0,3,'single') ;
opts.interpolation = 'bilinear' ;
opts.numThreads = 1 ;
opts.prefetch = false ;
opts = vl_argparse(opts, varargin);
% fetch is true if images is a list of filenames (instead of
% a cell array of images)
fetch = numel(images) >= 1 && ischar(images{1}) ;
% prefetch is used to load images in a separate thread
prefetch = fetch & opts.prefetch ;
if prefetch
vl_imreadjpeg(images, 'numThreads', opts.numThreads, 'prefetch') ;
imo = [] ;
return ;
end
if fetch
im = vl_imreadjpeg(images,'numThreads', opts.numThreads) ; % read the image files in one batched call
else
im = images ;
end
tfs = [] ; % define the augmentation transformations
switch opts.transformation
case 'none'
tfs = [
.5 ;
.5 ;
0 ] ;
case 'f5'
tfs = [...
.5 0 0 1 1 .5 0 0 1 1 ;
.5 0 1 0 1 .5 0 1 0 1 ;
0 0 0 0 0 1 1 1 1 1] ;
case 'f25'
[tx,ty] = meshgrid(linspace(0,1,5)) ;
tfs = [tx(:)' ; ty(:)' ; zeros(1,numel(tx))] ;
tfs_ = tfs ;
tfs_(3,:) = 1 ;
tfs = [tfs,tfs_] ;
case 'stretch'
otherwise
error('Unknown transformation %s', opts.transformation) ;
end
[~,transformations] = sort(rand(size(tfs,2), numel(images)), 1) ;
if ~isempty(opts.rgbVariance) && isempty(opts.averageImage)
opts.averageImage = zeros(1,1,3) ;
end
if numel(opts.averageImage) == 3
opts.averageImage = reshape(opts.averageImage, 1,1,3) ;
end
imo = zeros(opts.imageSize(1), opts.imageSize(2), 3, ...
numel(images)*opts.numAugments, 'single') ;
si = 1 ;
for i=1:numel(images)
% acquire image
if isempty(im{i})
imt = imread(images{i}) ;
imt = single(imt) ; % faster than im2single (and multiplies by 255)
else
imt = im{i} ;
end
if size(imt,3) == 1
imt = cat(3, imt, imt, imt) ;
end
% resize
w = size(imt,2) ;
h = size(imt,1) ;
factor = [(opts.imageSize(1)+opts.border(1))/h ...
(opts.imageSize(2)+opts.border(2))/w];
if opts.keepAspect
factor = max(factor) ;
end
if any(abs(factor - 1) > 0.0001) % resize unless the scale factor is (almost) exactly 1
imt = imresize(imt, ...
'scale', factor, ...
'method', opts.interpolation) ;
end
% crop & flip
w = size(imt,2) ;
h = size(imt,1) ;
for ai = 1:opts.numAugments
switch opts.transformation
case 'stretch'
sz = round(min(opts.imageSize(1:2)' .* (1-0.1+0.2*rand(2,1)), [h;w])) ;
dx = randi(w - sz(2) + 1, 1) ;
dy = randi(h - sz(1) + 1, 1) ;
flip = rand > 0.5 ;
otherwise
tf = tfs(:, transformations(mod(ai-1, numel(transformations)) + 1)) ;
sz = opts.imageSize(1:2) ;
dx = floor((w - sz(2)) * tf(2)) + 1 ;
dy = floor((h - sz(1)) * tf(1)) + 1 ;
flip = tf(3) ;
end
sx = round(linspace(dx, sz(2)+dx-1, opts.imageSize(2))) ;
sy = round(linspace(dy, sz(1)+dy-1, opts.imageSize(1))) ;
if flip, sx = fliplr(sx) ; end
if ~isempty(opts.averageImage)
offset = opts.averageImage ;
if ~isempty(opts.rgbVariance)
offset = bsxfun(@plus, offset, reshape(opts.rgbVariance * randn(3,1), 1,1,3)) ;
end
imo(:,:,:,si) = bsxfun(@minus, imt(sy,sx,:), offset) ;
else
imo(:,:,:,si) = imt(sy,sx,:) ;
end
si = si + 1 ;
end
end
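Note that cropping and flipping are done purely with index vectors: sx/sy select the crop window, and fliplr(sx) mirrors it horizontally without touching pixel data. A standalone sketch:

imt = single(rand(256, 288, 3)) ;            % source image, h x w = 256 x 288
sz = [224 224] ; dx = 10 ; dy = 5 ;          % crop size and top-left offset
sx = round(linspace(dx, sz(2)+dx-1, 224)) ;  % column indices of the window
sy = round(linspace(dy, sz(1)+dy-1, 224)) ;  % row indices of the window
sx = fliplr(sx) ;                            % horizontal flip
crop = imt(sy, sx, :) ;                      % 224x224x3 mirrored crop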
8. Training function
function [net,stats] = cnn_train_dag(net, imdb, getBatch, varargin)
%CNN_TRAIN_DAG Demonstrates training a CNN using the DagNN wrapper
% CNN_TRAIN_DAG() is similar to CNN_TRAIN(), but works with
% the DagNN wrapper instead of the SimpleNN wrapper.
% Copyright (C) 2014-16 Andrea Vedaldi.
% All rights reserved.
%
% This file is part of the VLFeat library and is made available under
% the terms of the BSD license (see the COPYING file).
addpath(fullfile(vl_rootnn, 'examples'));
opts.expDir = fullfile('data','exp') ;
opts.continue = true ;
opts.batchSize = 256 ;
opts.numSubBatches = 1 ;
opts.train = [] ;
opts.val = [] ;
opts.gpus = [1] ;
opts.prefetch = false ;
opts.epochSize = inf;
opts.numEpochs = 20 ;
opts.learningRate = 0.001 ;
opts.weightDecay = 0.0005 ;
opts.solver = [] ; % empty means use the default SGD solver
[opts, varargin] = vl_argparse(opts, varargin) ;
if ~isempty(opts.solver)
assert(isa(opts.solver, 'function_handle') && nargout(opts.solver) == 2,...
'Invalid solver; expected a function handle with two outputs.') ;
% Call without input arguments, to get default options
opts.solverOpts = opts.solver() ;
end
opts.momentum = 0.9 ;
opts.saveSolverState = true ;
opts.nesterovUpdate = false ;
opts.randomSeed = 0 ;
opts.profile = false ;
opts.parameterServer.method = 'mmap' ;
opts.parameterServer.prefix = 'mcn' ;
opts.derOutputs = {'objective', 1} ;
opts.extractStatsFn = @extractStats ;
opts.plotStatistics = true;
opts.postEpochFn = [] ; % postEpochFn(net,params,state) called after each epoch; can return a new learning rate, 0 to stop, [] for no change
opts = vl_argparse(opts, varargin) ;
if ~exist(opts.expDir, 'dir'), mkdir(opts.expDir) ; end
if isempty(opts.train), opts.train = find(imdb.images.set==1) ; end
if isempty(opts.val), opts.val = find(imdb.images.set==2) ; end
if isscalar(opts.train) && isnumeric(opts.train) && isnan(opts.train)
opts.train = [] ;
end
if isscalar(opts.val) && isnumeric(opts.val) && isnan(opts.val)
opts.val = [] ;
end
% -------------------------------------------------------------------------
% Initialization
% -------------------------------------------------------------------------
evaluateMode = isempty(opts.train) ;
if ~evaluateMode
if isempty(opts.derOutputs)
error('DEROUTPUTS must be specified when training.\n') ;
end
end
% -------------------------------------------------------------------------
% Train and validate
% -------------------------------------------------------------------------
modelPath = @(ep) fullfile(opts.expDir, sprintf('net-epoch-%d.mat', ep));
modelFigPath = fullfile(opts.expDir, 'net-train.pdf') ;
start = opts.continue * findLastCheckpoint(opts.expDir) ; % resume from the last saved checkpoint (note: this is where my run once got stuck)
if start >= 1
fprintf('%s: resuming by loading epoch %d\n', mfilename, start) ;
[net, state, stats] = loadState(modelPath(start)) ;
else
state = [] ;
end
for epoch=start+1:opts.numEpochs
% Set the random seed based on the epoch and opts.randomSeed.
% This is important for reproducibility, including when training
% is restarted from a checkpoint.
rng(epoch + opts.randomSeed) ;
prepareGPUs(opts, epoch == start+1) ;
% Train for one epoch.
params = opts ;
params.epoch = epoch ;
params.learningRate = opts.learningRate(min(epoch, numel(opts.learningRate))) ;
params.train = opts.train(randperm(numel(opts.train))) ; % shuffle
params.train = params.train(1:min(opts.epochSize, numel(opts.train)));
params.val = opts.val(randperm(numel(opts.val))) ;
params.imdb = imdb ;
params.getBatch = getBatch ;
if numel(opts.gpus) <= 1
[net, state] = processEpoch(net, state, params, 'train') ;
[net, state] = processEpoch(net, state, params, 'val') ;
if ~evaluateMode
saveState(modelPath(epoch), net, state) ;
end
lastStats = state.stats ;
else
spmd
[net, state] = processEpoch(net, state, params, 'train') ;
[net, state] = processEpoch(net, state, params, 'val') ;
if labindex == 1 && ~evaluateMode
saveState(modelPath(epoch), net, state) ;
end
lastStats = state.stats ;
end
lastStats = accumulateStats(lastStats) ;
end
stats.train(epoch) = lastStats.train ;
stats.val(epoch) = lastStats.val ;
clear lastStats ;
saveStats(modelPath(epoch), stats) ;
if opts.plotStatistics
switchFigure(1) ; clf ;
plots = setdiff(...
cat(2,...
fieldnames(stats.train)', ...
fieldnames(stats.val)'), {'num', 'time'}) ;
for p = plots
p = char(p) ;
values = zeros(0, epoch) ;
leg = {} ;
for f = {'train', 'val'}
f = char(f) ;
if isfield(stats.(f), p)
tmp = [stats.(f).(p)] ;
values(end+1,:) = tmp(1,:)' ;
leg{end+1} = f ;
end
end
subplot(1,numel(plots),find(strcmp(p,plots))) ;
plot(1:epoch, values','o-') ;
xlabel('epoch') ;
title(p) ;
legend(leg{:}) ;
grid on ;
end
drawnow ;
print(1, modelFigPath, '-dpdf') ;
end
if ~isempty(opts.postEpochFn)
if nargout(opts.postEpochFn) == 0
opts.postEpochFn(net, params, state) ;
else
lr = opts.postEpochFn(net, params, state) ;
if ~isempty(lr), opts.learningRate = lr; end
if opts.learningRate == 0, break; end
end
end
end
% With multiple GPUs, return one copy
if isa(net, 'Composite'), net = net{1} ; end
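One detail worth pulling out: the per-epoch learning rate is read from the schedule with a clamped index, so epochs beyond the end of the schedule keep its last value. A sketch:

learningRate = 1e-4 * [ones(1,10), 0.1*ones(1,5)] ; % 15-entry schedule
for epoch = [1 10 11 15 20]
  lr = learningRate(min(epoch, numel(learningRate))) ;
  fprintf('epoch %2d: lr = %g\n', epoch, lr) ; % epoch 20 still uses 1e-5
end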
% -------------------------------------------------------------------------
function [net, state] = processEpoch(net, state, params, mode)
% -------------------------------------------------------------------------
% Note that net is not strictly needed as an output argument as net
% is a handle class. However, this fixes some aliasing issue in the
% spmd caller.
% initialize with momentum 0
if isempty(state) || isempty(state.solverState)
state.solverState = cell(1, numel(net.params)) ;
state.solverState(:) = {0} ;
end
% move CNN to GPU as needed
numGpus = numel(params.gpus) ;
if numGpus >= 1
net.move('gpu') ;
for i = 1:numel(state.solverState)
s = state.solverState{i} ;
if isnumeric(s)
state.solverState{i} = gpuArray(s) ;
elseif isstruct(s)
state.solverState{i} = structfun(@gpuArray, s, 'UniformOutput', false) ;
end
end
end
if numGpus > 1
parserv = ParameterServer(params.parameterServer) ;
net.setParameterServer(parserv) ;
else
parserv = [] ;
end
% profile
if params.profile
if numGpus <= 1
profile clear ;
profile on ;
else
mpiprofile reset ;
mpiprofile on ;
end
end
num = 0 ;
epoch = params.epoch ;
subset = params.(mode) ;
adjustTime = 0 ;
stats.num = 0 ; % return something even if subset = []
stats.time = 0 ;
start = tic ;
for t=1:params.batchSize:numel(subset)
fprintf('%s: epoch %02d: %3d/%3d:', mode, epoch, ...
fix((t-1)/params.batchSize)+1, ceil(numel(subset)/params.batchSize)) ;
batchSize = min(params.batchSize, numel(subset) - t + 1) ;
for s=1:params.numSubBatches
% get this image batch and prefetch the next
batchStart = t + (labindex-1) + (s-1) * numlabs ;
batchEnd = min(t+params.batchSize-1, numel(subset)) ;
batch = subset(batchStart : params.numSubBatches * numlabs : batchEnd) ;
num = num + numel(batch) ;
if numel(batch) == 0, continue ; end
inputs = params.getBatch(params.imdb, batch) ;
if params.prefetch
if s == params.numSubBatches
batchStart = t + (labindex-1) + params.batchSize ;
batchEnd = min(t+2*params.batchSize-1, numel(subset)) ;
else
batchStart = batchStart + numlabs ;
end
nextBatch = subset(batchStart : params.numSubBatches * numlabs : batchEnd) ;
params.getBatch(params.imdb, nextBatch) ;
end
if strcmp(mode, 'train')
net.mode = 'normal' ;
net.accumulateParamDers = (s ~= 1) ;
net.eval(inputs, params.derOutputs, 'holdOn', s < params.numSubBatches) ;
else
net.mode = 'test' ;
net.eval(inputs) ;
end
end
% Accumulate gradient.
if strcmp(mode, 'train')
if ~isempty(parserv), parserv.sync() ; end
state = accumulateGradients(net, state, params, batchSize, parserv) ;
end
% Get statistics.
time = toc(start) + adjustTime ;
batchTime = time - stats.time ;
stats.num = num ;
stats.time = time ;
stats = params.extractStatsFn(stats,net) ;
currentSpeed = batchSize / batchTime ;
averageSpeed = (t + batchSize - 1) / time ;
if t == 3*params.batchSize + 1
% compensate for the first three iterations, which are outliers
adjustTime = 4*batchTime - time ;
stats.time = time + adjustTime ;
end
fprintf(' %.1f (%.1f) Hz', averageSpeed, currentSpeed) ;
for f = setdiff(fieldnames(stats)', {'num', 'time'})
f = char(f) ;
fprintf(' %s: %.3f', f, stats.(f)) ;
end
fprintf('\n') ;
end
% Save back to state.
state.stats.(mode) = stats ;
if params.profile
if numGpus <= 1
state.prof.(mode) = profile('info') ;
profile off ;
else
state.prof.(mode) = mpiprofile('info');
mpiprofile off ;
end
end
if ~params.saveSolverState
state.solverState = [] ;
else
for i = 1:numel(state.solverState)
s = state.solverState{i} ;
if isnumeric(s)
state.solverState{i} = gather(s) ;
elseif isstruct(s)
state.solverState{i} = structfun(@gather, s, 'UniformOutput', false) ;
end
end
end
net.reset() ;
net.move('cpu') ;
% -------------------------------------------------------------------------
function state = accumulateGradients(net, state, params, batchSize, parserv)
% -------------------------------------------------------------------------
numGpus = numel(params.gpus) ;
otherGpus = setdiff(1:numGpus, labindex) ;
for p=1:numel(net.params)
if ~isempty(parserv)
parDer = parserv.pullWithIndex(p) ;
else
parDer = net.params(p).der ;
end
switch net.params(p).trainMethod
case 'average' % mainly for batch normalization
thisLR = net.params(p).learningRate ;
net.params(p).value = vl_taccum(...
1 - thisLR, net.params(p).value, ...
(thisLR/batchSize/net.params(p).fanout), parDer) ;
case 'gradient'
thisDecay = params.weightDecay * net.params(p).weightDecay ;
thisLR = params.learningRate * net.params(p).learningRate ;
if thisLR>0 || thisDecay>0
% Normalize gradient and incorporate weight decay.
parDer = vl_taccum(1/batchSize, parDer, ...
thisDecay, net.params(p).value) ;
if isempty(params.solver)
% Default solver is the optimised SGD.
% Update momentum.
state.solverState{p} = vl_taccum(...
params.momentum, state.solverState{p}, ...
-1, parDer) ;
% Nesterov update (aka one step ahead).
if params.nesterovUpdate
delta = params.momentum * state.solverState{p} - parDer ;
else
delta = state.solverState{p} ;
end
% Update parameters.
net.params(p).value = vl_taccum(...
1, net.params(p).value, thisLR, delta) ;
else
% call solver function to update weights
[net.params(p).value, state.solverState{p}] = ...
params.solver(net.params(p).value, state.solverState{p}, ...
parDer, params.solverOpts, thisLR) ;
end
end
otherwise
error('Unknown training method ''%s'' for parameter ''%s''.', ...
net.params(p).trainMethod, ...
net.params(p).name) ;
end
end
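For the default solver, the two vl_taccum calls implement classic SGD with momentum. Written out in scalar form (a sketch that ignores weight decay, Nesterov, and the parameter server):

% v <- momentum * v - g   (velocity update; g is the gradient already averaged over the batch)
% w <- w + lr * v         (parameter update)
v = 0 ; w = 1 ; lr = 1e-4 ; momentum = 0.9 ; g = 0.5 ;
v = momentum * v - g ;
w = w + lr * v ;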
% -------------------------------------------------------------------------
function stats = accumulateStats(stats_)
% -------------------------------------------------------------------------
for s = {'train', 'val'}
s = char(s) ;
total = 0 ;
% initialize stats stucture with same fields and same order as
% stats_{1}
stats__ = stats_{1} ;
names = fieldnames(stats__.(s))' ;
values = zeros(1, numel(names)) ;
fields = cat(1, names, num2cell(values)) ;
stats.(s) = struct(fields{:}) ;
for g = 1:numel(stats_)
stats__ = stats_{g} ;
num__ = stats__.(s).num ;
total = total + num__ ;
for f = setdiff(fieldnames(stats__.(s))', 'num')
f = char(f) ;
stats.(s).(f) = stats.(s).(f) + stats__.(s).(f) * num__ ;
if g == numel(stats_)
stats.(s).(f) = stats.(s).(f) / total ;
end
end
end
stats.(s).num = total ;
end
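accumulateStats is a sample-count-weighted average of each statistic over the per-GPU labs; in scalar form:

vals = [0.80 0.70] ; nums = [100 50] ; % per-lab error and sample count
avg = sum(vals .* nums) / sum(nums)    % weighted mean, here 0.7667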
% -------------------------------------------------------------------------
function stats = extractStats(stats, net)
% -------------------------------------------------------------------------
sel = find(cellfun(@(x) isa(x,'dagnn.Loss'), {net.layers.block})) ;
for i = 1:numel(sel)
if net.layers(sel(i)).block.ignoreAverage, continue; end;
stats.(net.layers(sel(i)).outputs{1}) = net.layers(sel(i)).block.average ;
end
% -------------------------------------------------------------------------
function saveState(fileName, net_, state)
% -------------------------------------------------------------------------
net = net_.saveobj() ;
save(fileName, 'net', 'state') ;
% -------------------------------------------------------------------------
function saveStats(fileName, stats)
% -------------------------------------------------------------------------
if exist(fileName, 'file')
save(fileName, 'stats', '-append') ;
else
save(fileName, 'stats') ;
end
% -------------------------------------------------------------------------
function [net, state, stats] = loadState(fileName)
% -------------------------------------------------------------------------
load(fileName, 'net', 'state', 'stats') ;
net = dagnn.DagNN.loadobj(net) ;
if isempty(whos('stats'))
error('Epoch ''%s'' was only partially saved. Delete this file and try again.', ...
fileName) ;
end
% -------------------------------------------------------------------------
function epoch = findLastCheckpoint(modelDir)
% -------------------------------------------------------------------------
list = dir(fullfile(modelDir, 'net-epoch-*.mat')) ;
tokens = regexp({list.name}, 'net-epoch-([\d]+).mat', 'tokens') ;
epoch = cellfun(@(x) sscanf(x{1}{1}, '%d'), tokens) ;
epoch = max([epoch 0]) ;
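findLastCheckpoint recovers epoch numbers from the saved file names via regexp tokens; the same pattern on a literal cell array:

names = {'net-epoch-1.mat', 'net-epoch-12.mat'} ;
tokens = regexp(names, 'net-epoch-([\d]+).mat', 'tokens') ;
epochs = cellfun(@(x) sscanf(x{1}{1}, '%d'), tokens) ; % [1 12]
lastEpoch = max([epochs 0])                            % 12, or 0 if no checkpoints exist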
% -------------------------------------------------------------------------
function switchFigure(n)
% -------------------------------------------------------------------------
if get(0,'CurrentFigure') ~= n
try
set(0,'CurrentFigure',n) ;
catch
figure(n) ;
end
end
% -------------------------------------------------------------------------
function clearMex()
% -------------------------------------------------------------------------
clear vl_tmove vl_imreadjpeg ;
% -------------------------------------------------------------------------
function prepareGPUs(opts, cold)
% -------------------------------------------------------------------------
numGpus = numel(opts.gpus) ;
if numGpus > 1
% check parallel pool integrity as it could have timed out
pool = gcp('nocreate') ;
if ~isempty(pool) && pool.NumWorkers ~= numGpus
delete(pool) ;
end
pool = gcp('nocreate') ;
if isempty(pool)
parpool('local', numGpus) ;
cold = true ;
end
end
if numGpus >= 1 && cold
fprintf('%s: resetting GPU\n', mfilename)
clearMex() ;
if numGpus == 1
gpuDevice(opts.gpus)
else
spmd
clearMex() ;
gpuDevice(opts.gpus(labindex))
end
end
end