2019-04-18 Matconvnet:matlab建立自己的數(shù)據(jù)集并訓(xùn)練(代碼分析備注)

這一篇文章是上一篇文章的代碼分析。

1.主函數(shù)

function [net, info] = cnn_dicnn(varargin)

% 預(yù)先setup毁渗,把各個(gè)子文件夾都加入到路徑中能真。
%生成當(dāng)前文件所在的完整目錄嘶摊,包括文件名-------mfilename('fullpath')
%文件完整目錄分割成目錄顽决、文件名和后綴-------[pathstr,name,ext]= fileparts(filename)

run(fullfile(fileparts(mfilename('fullpath')),'matconvnet', 'matlab', 'vl_setupnn.m')) ;

% 讀入文件夾的路徑
opts.dataDir = fullfile('data','image') ;
opts.expDir  = fullfile('exp', 'image') ;
% 讀入預(yù)訓(xùn)練的model的路徑
opts.modelPath = fullfile('models','imagenet-vgg-f.mat');

%將輸入變量的par-val參數(shù)對(duì)加到opts結(jié)構(gòu)體中
[opts, varargin] = vl_argparse(opts, varargin) ;

opts.numFetchThreads = 12 ;         ??????????

opts.lite = false ;                               ?????????
opts.imdbPath = fullfile(opts.expDir, 'imdb.mat');       %imdb數(shù)據(jù)的地址

%對(duì)訓(xùn)練的參數(shù)加一個(gè)參數(shù)結(jié)構(gòu)體
opts.train = struct() ;
opts.train.gpus = [1];         %是否使用GPU
opts.train.batchSize = 8 ;        %batch大小
opts.train.numSubBatches = 4 ;
opts.train.learningRate = 1e-4 * [ones(1,10), 0.1*ones(1,5)];     %學(xué)習(xí)率

opts = vl_argparse(opts, varargin) ;
if ~isfield(opts.train, 'gpus'), opts.train.gpus = []; end

% -------------------------------------------------------------------------
%                                                             Prepare model
% -------------------------------------------------------------------------
net = load(opts.modelPath);
% 修改一下這個(gè)model族铆,進(jìn)入函數(shù)2
net = prepareDINet(net,opts);
% -------------------------------------------------------------------------
%                                                              Prepare data
% -------------------------------------------------------------------------
% 準(zhǔn)備數(shù)據(jù)格式
if exist(opts.imdbPath,'file')
  imdb = load(opts.imdbPath) ;
else
  imdb = cnn_image_setup_data('dataDir', opts.dataDir, 'lite', opts.lite) ;   %進(jìn)入函數(shù)3
  mkdir(opts.expDir) ;                %創(chuàng)建文件夾exp/image
  save(opts.imdbPath, '-struct', 'imdb') ;   %保存結(jié)構(gòu)體将硝,https://ww2.mathworks.cn/help/matlab/ref/save.html
end 

imdb.images.set = imdb.images.sets;

%把原網(wǎng)絡(luò)的類別(1000類)描述換成自己的描述(10類)
net.meta.classes.name = imdb.classes.name ;
net.meta.classes.description = imdb.classes.name ;

% % 求訓(xùn)練集的均值恭朗,進(jìn)入函數(shù)4
imageStatsPath = fullfile(opts.expDir, 'imageStats.mat') ;
if exist(imageStatsPath)
  load(imageStatsPath, 'averageImage') ;
else
    averageImage = getImageStats(opts, net.meta, imdb) ;
    save(imageStatsPath, 'averageImage') ;
end
% % 用新的均值改變均值
net.meta.normalization.averageImage = averageImage;
% -------------------------------------------------------------------------
%                                                                     Learn
% -------------------------------------------------------------------------
% 索引訓(xùn)練集==1  和測(cè)試集==3
opts.train.train = find(imdb.images.set==1) ;
opts.train.val = find(imdb.images.set==3) ;
% 訓(xùn)練
[net, info] = cnn_train_dag(net, imdb, getBatchFn(opts, net.meta), ...
                      'expDir', opts.expDir, ...
                      opts.train) ;

% -------------------------------------------------------------------------
%                                                                    Deploy
% -------------------------------------------------------------------------
% 保存訓(xùn)練完的網(wǎng)絡(luò)
%net = cnn_imagenet_deploy(net) ;
net = cnn_imagenet_deploy(net);
modelPath = fullfile(opts.expDir, 'net-deployed.mat');

net_ = net.saveobj() ;
save(modelPath, '-struct', 'net_') ;
clear net_ ;


2.model預(yù)調(diào)整

% -------------------------------------------------------------------------
function net = prepareDINet(net,opts)
% -------------------------------------------------------------------------
% 把 fc8層換成fc8l(原理還是有點(diǎn)搞不明白)
%對(duì)元胞數(shù)組中的每個(gè)元胞應(yīng)用函數(shù),官方幫助文件:https://ww2.mathworks.cn/help/matlab/ref/cellfun.html
fc8l = cellfun(@(a) strcmp(a.name, 'fc8'), net.layers)==1;

%%  note: 下面這個(gè)是類別數(shù),一定要和自己的類別數(shù)吻合(這里為10類)
nCls = 10;
sizeW = size(net.layers{fc8l}.weights{1});

%如果所需類別數(shù)和原網(wǎng)絡(luò)類別不一樣依疼。則用0初始化權(quán)重參數(shù) 
if sizeW(4)~=nCls
  net.layers{fc8l}.weights = {zeros(sizeW(1),sizeW(2),sizeW(3),nCls,'single'), ...
    zeros(1, nCls, 'single')};
end

% change loss  添加一個(gè)loss層用于訓(xùn)練
net.layers{end} = struct('name','loss', 'type','softmaxloss') ;

%將普通nn轉(zhuǎn)化成dagnn,比較靈活痰腮,參考鏈接:https://www.cnblogs.com/ironstark/p/6058090.html
net = dagnn.DagNN.fromSimpleNN(net, 'canonicalNames', true) ;

%添加error層
net.addLayer('top1err', dagnn.Loss('loss', 'classerror'), ...
    {'prediction','label'}, 'top1err') ;
net.addLayer('top5err', dagnn.Loss('loss', 'topkerror', ...
    'opts', {'topK',5}), ...
{'prediction','label'}, 'top5err') ;

3.數(shù)據(jù)集預(yù)處理

function imdb = cnn_image_setup_data(varargin)

opts.dataDir = fullfile('data','image') ;
opts.lite = false ;
opts = vl_argparse(opts, varargin) ;

% ------------------------------------------------------------------------
%                                                  Load categories metadata
% -------------------------------------------------------------------------

metaPath = fullfile(opts.dataDir, 'classInd.txt') ;

fprintf('using metadata %s\n', metaPath) ;
tmp = importdata(metaPath);
nCls = numel(tmp);
% 判斷類別與設(shè)定的是否一樣 10為樣本的類別總數(shù)(自己的數(shù)據(jù)集需要修改)
if nCls ~= 10
  error('Wrong meta file %s',metaPath);
end
% 將名字分離出來
cats = cell(1,nCls);
for i=1:numel(tmp)
  t = strsplit(tmp{i});
  cats{i} = t{2};
end
% 數(shù)據(jù)集文件夾選擇
imdb.classes.name = cats ;                      %類別名稱
imdb.imageDir.train = fullfile(opts.dataDir, 'train') ;         %訓(xùn)練數(shù)據(jù)地址
imdb.imageDir.test = fullfile(opts.dataDir, 'test') ;            %測(cè)試數(shù)據(jù)地址

%% -----------------------------------------------------------------
%                                              load image names and labels
% -------------------------------------------------------------------------

name = {};
labels = {} ;
imdb.images.sets = [] ;
%%
fprintf('searching training images ...\n') ;

% 導(dǎo)入訓(xùn)練類別標(biāo)簽
train_label_path = fullfile(opts.dataDir, 'train_label.txt') ;
train_label_temp = importdata(train_label_path);
temp_l = train_label_temp.data;
for i=1:numel(temp_l)
    train_label{i} = temp_l(i);
end
if length(train_label) ~= length(dir(fullfile(imdb.imageDir.train, '*.jpg')))
    error('training data is not equal to its label!!!');
end

i = 1;
for d = dir(fullfile(imdb.imageDir.train, '*.jpg'))'
    name{end+1} = d.name;
    labels{end+1} = train_label{i} ;
    if mod(numel(name), 10) == 0, fprintf('.') ; end
    if mod(numel(name), 500) == 0, fprintf('\n') ; end
    imdb.images.sets(end+1) = 1;%train
    i = i+1;
end
%%
fprintf('searching testing images ...\n') ;

% 導(dǎo)入測(cè)試類別標(biāo)簽
test_label_path = fullfile(opts.dataDir, 'test_label.txt') ;
test_label_temp = importdata(test_label_path);
temp_l = test_label_temp.data;
for i=1:numel(temp_l)
    test_label{i} = temp_l(i);
end
if length(test_label) ~= length(dir(fullfile(imdb.imageDir.test, '*.jpg')))
    error('testing data is not equal to its label!!!');
end
i = 1;
for d = dir(fullfile(imdb.imageDir.test, '*.jpg'))'
    name{end+1} = d.name;
    labels{end+1} = test_label{i} ;
    if mod(numel(name), 10) == 0, fprintf('.') ; end
    if mod(numel(name), 500) == 0, fprintf('\n') ; end
    imdb.images.sets(end+1) = 3;%test
    i = i+1;
end


labels = horzcat(labels{:}) ;    %horzcat水平串聯(lián)數(shù)組
imdb.images.id = 1:numel(name) ;    %給圖像編號(hào)
imdb.images.name = name ;          %圖像文件名
imdb.images.label = labels ;           %圖像標(biāo)簽

輸出的imdb數(shù)據(jù)結(jié)構(gòu)如下圖所示。


4.求樣本均值

% 求訓(xùn)練樣本的均值
% -------------------------------------------------------------------------
function averageImage = getImageStats(opts, meta, imdb)
% -------------------------------------------------------------------------
train = find(imdb.images.set == 1) ;  %找出識(shí)別號(hào)為1的(代表訓(xùn)練集)
batch = 1:length(train);
fn = getBatchFn(opts, meta) ;            %步入函數(shù)5
train = train(1: 100: end);              %按照100個(gè)為一個(gè)batch
avg = {};
for i = 1:length(train)
    temp = fn(imdb, batch(train(i):train(i)+99)) ;   %temp為圖像+標(biāo)簽的序列
    temp = temp{2};                                  %只為圖像 224x224x3x100
    avg{end+1} = mean(temp, 4) ;                     %將其按照第四維求平均,得到10個(gè)平均圖像
end

averageImage = mean(cat(4,avg{:}),4) ;            %再將這十個(gè)圖像求平均
% 將GPU格式的轉(zhuǎn)化為cpu格式的保存起來(如果有用GPU)
averageImage = gather(averageImage);

5.定義Fn函數(shù)

function fn = getBatchFn(opts, meta)
% -------------------------------------------------------------------------
useGpu = numel(opts.train.gpus) > 0 ;   %是否使用GPU

bopts.numThreads = opts.numFetchThreads ;            %12
bopts.imageSize = meta.normalization.imageSize ;     %[224,224,3,10]
bopts.border = meta.normalization.border ;           %[32,32]
% bopts.averageImage = []; 
bopts.averageImage = meta.normalization.averageImage ;      %224x224x3 double
% bopts.rgbVariance = meta.augmentation.rgbVariance ;
% bopts.transformation = meta.augmentation.transformation ;

fn = @(x,y) getDagNNBatch(bopts,useGpu,x,y) ;   %定義function:fn=@(x,y)getDagNNBatch(bopts,useGpu,x,y)

6.input = {'input', im, 'label', labels}

產(chǎn)生input cell律罢,input:224x224x3x100, labels:1x100膀值。將100個(gè)圖像以及標(biāo)簽打包

function inputs = getDagNNBatch(opts, useGpu, imdb, batch)
% -------------------------------------------------------------------------
% 判斷讀入數(shù)據(jù)為訓(xùn)練還是測(cè)試
for i = 1:length(batch)
    if imdb.images.set(batch(i)) == 1 %1為訓(xùn)練索引文件夾
        images(i) = strcat([imdb.imageDir.train filesep] , imdb.images.name(batch(i)));    %filesep為文件分割符,strcat為橫向連接字符串
    else
        images(i) = strcat([imdb.imageDir.test filesep] , imdb.images.name(batch(i)));
    end
end
isVal = ~isempty(batch) && imdb.images.set(batch(1)) ~= 1 ;

if ~isVal
  % training

  im = cnn_imagenet_get_batch(images, opts, ...
                               'prefetch', nargout == 0) ;     %步入函數(shù)7
else
  % validation: disable data augmentation
  im = cnn_imagenet_get_batch(images, opts, ...
                              'prefetch', nargout == 0, ...
                              'transformation', 'none') ;
end

if nargout > 0
  if useGpu
    im = gpuArray(im) ;
  end
  labels = imdb.images.label(batch) ;
  inputs = {'input', im, 'label', labels} ;
end

7.打包圖像误辑,預(yù)處理(到規(guī)定網(wǎng)絡(luò)大小和平均值)

function imo = cnn_imagenet_get_batch(images, varargin)
% CNN_IMAGENET_GET_BATCH  Load, preprocess, and pack images for CNN evaluation

opts.imageSize = [227, 227] ;
opts.border = [29, 29] ;
opts.keepAspect = true ;
opts.numAugments = 1 ;
opts.transformation = 'none' ;
opts.averageImage = [] ;
opts.rgbVariance = zeros(0,3,'single') ;
opts.interpolation = 'bilinear' ;
opts.numThreads = 1 ;
opts.prefetch = false ;
opts = vl_argparse(opts, varargin);

% fetch is true if images is a list of filenames (instead of
% a cell array of images)
fetch = numel(images) >= 1 && ischar(images{1}) ;

% prefetch is used to load images in a separate thread
prefetch = fetch & opts.prefetch ;

if prefetch
  vl_imreadjpeg(images, 'numThreads', opts.numThreads, 'prefetch') ;
  imo = [] ;
  return ;
end
if fetch
  im = vl_imreadjpeg(images,'numThreads', opts.numThreads) ; %批量讀取圖像文件
else
  im = images ;
end

tfs = [] ;      %定義變換
switch opts.transformation
  case 'none'
    tfs = [
      .5 ;
      .5 ;
       0 ] ;
  case 'f5'
    tfs = [...
      .5 0 0 1 1 .5 0 0 1 1 ;
      .5 0 1 0 1 .5 0 1 0 1 ;
       0 0 0 0 0  1 1 1 1 1] ;
  case 'f25'
    [tx,ty] = meshgrid(linspace(0,1,5)) ;
    tfs = [tx(:)' ; ty(:)' ; zeros(1,numel(tx))] ;
    tfs_ = tfs ;
    tfs_(3,:) = 1 ;
    tfs = [tfs,tfs_] ;
  case 'stretch'
  otherwise
    error('Uknown transformations %s', opts.transformation) ;
end
[~,transformations] = sort(rand(size(tfs,2), numel(images)), 1) ;

if ~isempty(opts.rgbVariance) && isempty(opts.averageImage)
  opts.averageImage = zeros(1,1,3) ;
end
if numel(opts.averageImage) == 3
  opts.averageImage = reshape(opts.averageImage, 1,1,3) ;
end

imo = zeros(opts.imageSize(1), opts.imageSize(2), 3, ...
            numel(images)*opts.numAugments, 'single') ;

si = 1 ;
for i=1:numel(images)

  % acquire image
  if isempty(im{i})
    imt = imread(images{i}) ;
    imt = single(imt) ; % faster than im2single (and multiplies by 255)
  else
    imt = im{i} ;
  end
  if size(imt,3) == 1
    imt = cat(3, imt, imt, imt) ;
  end

  % resize
  w = size(imt,2) ;
  h = size(imt,1) ;
  factor = [(opts.imageSize(1)+opts.border(1))/h ...
            (opts.imageSize(2)+opts.border(2))/w];

  if opts.keepAspect
    factor = max(factor) ;
  end
  if any(abs(factor - 1) > 0.0001) %只要形變因子不為1沧踏,則resize圖像
    imt = imresize(imt, ...
                   'scale', factor, ...
                   'method', opts.interpolation) ;
  end

  % crop & flip
  w = size(imt,2) ;
  h = size(imt,1) ;
  for ai = 1:opts.numAugments
    switch opts.transformation
      case 'stretch'
        sz = round(min(opts.imageSize(1:2)' .* (1-0.1+0.2*rand(2,1)), [h;w])) ;
        dx = randi(w - sz(2) + 1, 1) ;
        dy = randi(h - sz(1) + 1, 1) ;
        flip = rand > 0.5 ;
      otherwise
        tf = tfs(:, transformations(mod(ai-1, numel(transformations)) + 1)) ;
        sz = opts.imageSize(1:2) ;
        dx = floor((w - sz(2)) * tf(2)) + 1 ;
        dy = floor((h - sz(1)) * tf(1)) + 1 ;
        flip = tf(3) ;
    end
    sx = round(linspace(dx, sz(2)+dx-1, opts.imageSize(2))) ;
    sy = round(linspace(dy, sz(1)+dy-1, opts.imageSize(1))) ;
    if flip, sx = fliplr(sx) ; end

    if ~isempty(opts.averageImage)
      offset = opts.averageImage ;
      if ~isempty(opts.rgbVariance)
        offset = bsxfun(@plus, offset, reshape(opts.rgbVariance * randn(3,1), 1,1,3)) ;
      end
      imo(:,:,:,si) = bsxfun(@minus, imt(sy,sx,:), offset) ;
    else
      imo(:,:,:,si) = imt(sy,sx,:) ;
    end
    si = si + 1 ;
  end
end

8.訓(xùn)練函數(shù)

function [net,stats] = cnn_train_dag(net, imdb, getBatch, varargin)
%CNN_TRAIN_DAG Demonstrates training a CNN using the DagNN wrapper
%    CNN_TRAIN_DAG() is similar to CNN_TRAIN(), but works with
%    the DagNN wrapper instead of the SimpleNN wrapper.

% Copyright (C) 2014-16 Andrea Vedaldi.
% All rights reserved.
%
% This file is part of the VLFeat library and is made available under
% the terms of the BSD license (see the COPYING file).
addpath(fullfile(vl_rootnn, 'examples'));

opts.expDir = fullfile('data','exp') ;
opts.continue = true ;
opts.batchSize = 256 ;
opts.numSubBatches = 1 ;
opts.train = [] ;
opts.val = [] ;
opts.gpus = [1] ;
opts.prefetch = false ;
opts.epochSize = inf;
opts.numEpochs = 20 ;
opts.learningRate = 0.001 ;
opts.weightDecay = 0.0005 ;

opts.solver = [] ;  % 空集代表使用 SGD優(yōu)化方法
[opts, varargin] = vl_argparse(opts, varargin) ;
if ~isempty(opts.solver)
  assert(isa(opts.solver, 'function_handle') && nargout(opts.solver) == 2,...
    'Invalid solver; expected a function handle with two outputs.') ;
  % Call without input arguments, to get default options
  opts.solverOpts = opts.solver() ;
end

opts.momentum = 0.9 ;
opts.saveSolverState = true ;
opts.nesterovUpdate = false ;
opts.randomSeed = 0 ;
opts.profile = false ;
opts.parameterServer.method = 'mmap' ;
opts.parameterServer.prefix = 'mcn' ;

opts.derOutputs = {'objective', 1} ;
opts.extractStatsFn = @extractStats ;
opts.plotStatistics = true;
opts.postEpochFn = [] ;  % postEpochFn(net,params,state) called after each epoch; can return a new learning rate, 0 to stop, [] for no change
opts = vl_argparse(opts, varargin) ;

if ~exist(opts.expDir, 'dir'), mkdir(opts.expDir) ; end
if isempty(opts.train), opts.train = find(imdb.images.set==1) ; end
if isempty(opts.val), opts.val = find(imdb.images.set==2) ; end
if isscalar(opts.train) && isnumeric(opts.train) && isnan(opts.train)
  opts.train = [] ;
end
if isscalar(opts.val) && isnumeric(opts.val) && isnan(opts.val)
  opts.val = [] ;
end

% -------------------------------------------------------------------------
%                                                            Initialization
% -------------------------------------------------------------------------

evaluateMode = isempty(opts.train) ;
if ~evaluateMode
  if isempty(opts.derOutputs)
    error('DEROUTPUTS must be specified when training.\n') ;
  end
end

% -------------------------------------------------------------------------
%                                                        Train and validate
% -------------------------------------------------------------------------

modelPath = @(ep) fullfile(opts.expDir, sprintf('net-epoch-%d.mat', ep));
modelFigPath = fullfile(opts.expDir, 'net-train.pdf') ;

start = opts.continue * findLastCheckpoint(opts.expDir) ; %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%卡在這里
if start >= 1
  fprintf('%s: resuming by loading epoch %d\n', mfilename, start) ;
  [net, state, stats] = loadState(modelPath(start)) ;
else
  state = [] ;
end

for epoch=start+1:opts.numEpochs

  % Set the random seed based on the epoch and opts.randomSeed.
  % This is important for reproducibility, including when training
  % is restarted from a checkpoint.

  rng(epoch + opts.randomSeed) ;
  prepareGPUs(opts, epoch == start+1) ;

  % Train for one epoch.
  params = opts ;
  params.epoch = epoch ;
  params.learningRate = opts.learningRate(min(epoch, numel(opts.learningRate))) ;
  params.train = opts.train(randperm(numel(opts.train))) ; % shuffle
  params.train = params.train(1:min(opts.epochSize, numel(opts.train)));
  params.val = opts.val(randperm(numel(opts.val))) ;
  params.imdb = imdb ;
  params.getBatch = getBatch ;

  if numel(opts.gpus) <= 1
    [net, state] = processEpoch(net, state, params, 'train') ;
    [net, state] = processEpoch(net, state, params, 'val') ;
    if ~evaluateMode
      saveState(modelPath(epoch), net, state) ;
    end
    lastStats = state.stats ;
  else
    spmd
      [net, state] = processEpoch(net, state, params, 'train') ;
      [net, state] = processEpoch(net, state, params, 'val') ;
      if labindex == 1 && ~evaluateMode
        saveState(modelPath(epoch), net, state) ;
      end
      lastStats = state.stats ;
    end
    lastStats = accumulateStats(lastStats) ;
  end

  stats.train(epoch) = lastStats.train ;
  stats.val(epoch) = lastStats.val ;
  clear lastStats ;
  saveStats(modelPath(epoch), stats) ;

  if opts.plotStatistics
    switchFigure(1) ; clf ;
    plots = setdiff(...
      cat(2,...
      fieldnames(stats.train)', ...
      fieldnames(stats.val)'), {'num', 'time'}) ;
    for p = plots
      p = char(p) ;
      values = zeros(0, epoch) ;
      leg = {} ;
      for f = {'train', 'val'}
        f = char(f) ;
        if isfield(stats.(f), p)
          tmp = [stats.(f).(p)] ;
          values(end+1,:) = tmp(1,:)' ;
          leg{end+1} = f ;
        end
      end
      subplot(1,numel(plots),find(strcmp(p,plots))) ;
      plot(1:epoch, values','o-') ;
      xlabel('epoch') ;
      title(p) ;
      legend(leg{:}) ;
      grid on ;
    end
    drawnow ;
    print(1, modelFigPath, '-dpdf') ;
  end
  
  if ~isempty(opts.postEpochFn)
    if nargout(opts.postEpochFn) == 0
      opts.postEpochFn(net, params, state) ;
    else
      lr = opts.postEpochFn(net, params, state) ;
      if ~isempty(lr), opts.learningRate = lr; end
      if opts.learningRate == 0, break; end
    end
  end
end

% With multiple GPUs, return one copy
if isa(net, 'Composite'), net = net{1} ; end

% -------------------------------------------------------------------------
function [net, state] = processEpoch(net, state, params, mode)
% -------------------------------------------------------------------------
% Note that net is not strictly needed as an output argument as net
% is a handle class. However, this fixes some aliasing issue in the
% spmd caller.

% initialize with momentum 0
if isempty(state) || isempty(state.solverState)
  state.solverState = cell(1, numel(net.params)) ;
  state.solverState(:) = {0} ;
end

% move CNN  to GPU as needed
numGpus = numel(params.gpus) ;
if numGpus >= 1
  net.move('gpu') ;
  for i = 1:numel(state.solverState)
    s = state.solverState{i} ;
    if isnumeric(s)
      state.solverState{i} = gpuArray(s) ;
    elseif isstruct(s)
      state.solverState{i} = structfun(@gpuArray, s, 'UniformOutput', false) ;
    end
  end
end
if numGpus > 1
  parserv = ParameterServer(params.parameterServer) ;
  net.setParameterServer(parserv) ;
else
  parserv = [] ;
end

% profile
if params.profile
  if numGpus <= 1
    profile clear ;
    profile on ;
  else
    mpiprofile reset ;
    mpiprofile on ;
  end
end

num = 0 ;
epoch = params.epoch ;
subset = params.(mode) ;
adjustTime = 0 ;

stats.num = 0 ; % return something even if subset = []
stats.time = 0 ;

start = tic ;
for t=1:params.batchSize:numel(subset)
  fprintf('%s: epoch %02d: %3d/%3d:', mode, epoch, ...
          fix((t-1)/params.batchSize)+1, ceil(numel(subset)/params.batchSize)) ;
  batchSize = min(params.batchSize, numel(subset) - t + 1) ;

  for s=1:params.numSubBatches
    % get this image batch and prefetch the next
    batchStart = t + (labindex-1) + (s-1) * numlabs ;
    batchEnd = min(t+params.batchSize-1, numel(subset)) ;
    batch = subset(batchStart : params.numSubBatches * numlabs : batchEnd) ;
    num = num + numel(batch) ;
    if numel(batch) == 0, continue ; end

    inputs = params.getBatch(params.imdb, batch) ;

    if params.prefetch
      if s == params.numSubBatches
        batchStart = t + (labindex-1) + params.batchSize ;
        batchEnd = min(t+2*params.batchSize-1, numel(subset)) ;
      else
        batchStart = batchStart + numlabs ;
      end
      nextBatch = subset(batchStart : params.numSubBatches * numlabs : batchEnd) ;
      params.getBatch(params.imdb, nextBatch) ;
    end

    if strcmp(mode, 'train')
      net.mode = 'normal' ;
      net.accumulateParamDers = (s ~= 1) ;
      net.eval(inputs, params.derOutputs, 'holdOn', s < params.numSubBatches) ;
    else
      net.mode = 'test' ;
      net.eval(inputs) ;
    end
  end

  % Accumulate gradient.
  if strcmp(mode, 'train')
    if ~isempty(parserv), parserv.sync() ; end
    state = accumulateGradients(net, state, params, batchSize, parserv) ;
  end

  % Get statistics.
  time = toc(start) + adjustTime ;
  batchTime = time - stats.time ;
  stats.num = num ;
  stats.time = time ;
  stats = params.extractStatsFn(stats,net) ;
  currentSpeed = batchSize / batchTime ;
  averageSpeed = (t + batchSize - 1) / time ;
  if t == 3*params.batchSize + 1
    % compensate for the first three iterations, which are outliers
    adjustTime = 4*batchTime - time ;
    stats.time = time + adjustTime ;
  end

  fprintf(' %.1f (%.1f) Hz', averageSpeed, currentSpeed) ;
  for f = setdiff(fieldnames(stats)', {'num', 'time'})
    f = char(f) ;
    fprintf(' %s: %.3f', f, stats.(f)) ;
  end
  fprintf('\n') ;
end

% Save back to state.
state.stats.(mode) = stats ;
if params.profile
  if numGpus <= 1
    state.prof.(mode) = profile('info') ;
    profile off ;
  else
    state.prof.(mode) = mpiprofile('info');
    mpiprofile off ;
  end
end
if ~params.saveSolverState
  state.solverState = [] ;
else
  for i = 1:numel(state.solverState)
    s = state.solverState{i} ;
    if isnumeric(s)
      state.solverState{i} = gather(s) ;
    elseif isstruct(s)
      state.solverState{i} = structfun(@gather, s, 'UniformOutput', false) ;
    end
  end
end

net.reset() ;
net.move('cpu') ;

% -------------------------------------------------------------------------
function state = accumulateGradients(net, state, params, batchSize, parserv)
% -------------------------------------------------------------------------
numGpus = numel(params.gpus) ;
otherGpus = setdiff(1:numGpus, labindex) ;

for p=1:numel(net.params)

  if ~isempty(parserv)
    parDer = parserv.pullWithIndex(p) ;
  else
    parDer = net.params(p).der ;
  end

  switch net.params(p).trainMethod

    case 'average' % mainly for batch normalization
      thisLR = net.params(p).learningRate ;
      net.params(p).value = vl_taccum(...
          1 - thisLR, net.params(p).value, ...
          (thisLR/batchSize/net.params(p).fanout),  parDer) ;

    case 'gradient'
      thisDecay = params.weightDecay * net.params(p).weightDecay ;
      thisLR = params.learningRate * net.params(p).learningRate ;

      if thisLR>0 || thisDecay>0
        % Normalize gradient and incorporate weight decay.
        parDer = vl_taccum(1/batchSize, parDer, ...
                           thisDecay, net.params(p).value) ;

        if isempty(params.solver)
          % Default solver is the optimised SGD.
          % Update momentum.
          state.solverState{p} = vl_taccum(...
            params.momentum, state.solverState{p}, ...
            -1, parDer) ;

          % Nesterov update (aka one step ahead).
          if params.nesterovUpdate
            delta = params.momentum * state.solverState{p} - parDer ;
          else
            delta = state.solverState{p} ;
          end

          % Update parameters.
          net.params(p).value = vl_taccum(...
            1,  net.params(p).value, thisLR, delta) ;

        else
          % call solver function to update weights
          [net.params(p).value, state.solverState{p}] = ...
            params.solver(net.params(p).value, state.solverState{p}, ...
            parDer, params.solverOpts, thisLR) ;
        end
      end
    otherwise
      error('Unknown training method ''%s'' for parameter ''%s''.', ...
        net.params(p).trainMethod, ...
        net.params(p).name) ;
  end
end

% -------------------------------------------------------------------------
function stats = accumulateStats(stats_)
% -------------------------------------------------------------------------

for s = {'train', 'val'}
  s = char(s) ;
  total = 0 ;

  % initialize stats stucture with same fields and same order as
  % stats_{1}
  stats__ = stats_{1} ;
  names = fieldnames(stats__.(s))' ;
  values = zeros(1, numel(names)) ;
  fields = cat(1, names, num2cell(values)) ;
  stats.(s) = struct(fields{:}) ;

  for g = 1:numel(stats_)
    stats__ = stats_{g} ;
    num__ = stats__.(s).num ;
    total = total + num__ ;

    for f = setdiff(fieldnames(stats__.(s))', 'num')
      f = char(f) ;
      stats.(s).(f) = stats.(s).(f) + stats__.(s).(f) * num__ ;

      if g == numel(stats_)
        stats.(s).(f) = stats.(s).(f) / total ;
      end
    end
  end
  stats.(s).num = total ;
end

% -------------------------------------------------------------------------
function stats = extractStats(stats, net)
% -------------------------------------------------------------------------
sel = find(cellfun(@(x) isa(x,'dagnn.Loss'), {net.layers.block})) ;
for i = 1:numel(sel)
  if net.layers(sel(i)).block.ignoreAverage, continue; end;
  stats.(net.layers(sel(i)).outputs{1}) = net.layers(sel(i)).block.average ;
end

% -------------------------------------------------------------------------
function saveState(fileName, net_, state)
% -------------------------------------------------------------------------
net = net_.saveobj() ;
save(fileName, 'net', 'state') ;

% -------------------------------------------------------------------------
function saveStats(fileName, stats)
% -------------------------------------------------------------------------
if exist(fileName)
  save(fileName, 'stats', '-append') ;
else
  save(fileName, 'stats') ;
end

% -------------------------------------------------------------------------
function [net, state, stats] = loadState(fileName)
% -------------------------------------------------------------------------
load(fileName, 'net', 'state', 'stats') ;
net = dagnn.DagNN.loadobj(net) ;
if isempty(whos('stats'))
  error('Epoch ''%s'' was only partially saved. Delete this file and try again.', ...
        fileName) ;
end

% -------------------------------------------------------------------------
function epoch = findLastCheckpoint(modelDir)
% -------------------------------------------------------------------------
list = dir(fullfile(modelDir, 'net-epoch-*.mat')) ;
tokens = regexp({list.name}, 'net-epoch-([\d]+).mat', 'tokens') ;
epoch = cellfun(@(x) sscanf(x{1}{1}, '%d'), tokens) ;
epoch = max([epoch 0]) ;

% -------------------------------------------------------------------------
function switchFigure(n)
% -------------------------------------------------------------------------
if get(0,'CurrentFigure') ~= n
  try
    set(0,'CurrentFigure',n) ;
  catch
    figure(n) ;
  end
end

% -------------------------------------------------------------------------
function clearMex()
% -------------------------------------------------------------------------
clear vl_tmove vl_imreadjpeg ;

% -------------------------------------------------------------------------
function prepareGPUs(opts, cold)
% -------------------------------------------------------------------------
numGpus = numel(opts.gpus) ;
if numGpus > 1
  % check parallel pool integrity as it could have timed out
  pool = gcp('nocreate') ;
  if ~isempty(pool) && pool.NumWorkers ~= numGpus
    delete(pool) ;
  end
  pool = gcp('nocreate') ;
  if isempty(pool)
    parpool('local', numGpus) ;
    cold = true ;
  end

end
if numGpus >= 1 && cold
  fprintf('%s: resetting GPU\n', mfilename)
  clearMex() ;
  if numGpus == 1
    gpuDevice(opts.gpus)
  else
    spmd
      clearMex() ;
      gpuDevice(opts.gpus(labindex))
    end
  end
end

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市巾钉,隨后出現(xiàn)的幾起案子翘狱,更是在濱河造成了極大的恐慌,老刑警劉巖睛琳,帶你破解...
    沈念sama閱讀 221,406評(píng)論 6 515
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件盒蟆,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡师骗,警方通過查閱死者的電腦和手機(jī)历等,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 94,395評(píng)論 3 398
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來辟癌,“玉大人寒屯,你說我怎么就攤上這事。” “怎么了寡夹?”我有些...
    開封第一講書人閱讀 167,815評(píng)論 0 360
  • 文/不壞的土叔 我叫張陵处面,是天一觀的道長(zhǎng)。 經(jīng)常有香客問我菩掏,道長(zhǎng)魂角,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 59,537評(píng)論 1 296
  • 正文 為了忘掉前任智绸,我火速辦了婚禮野揪,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘瞧栗。我一直安慰自己斯稳,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 68,536評(píng)論 6 397
  • 文/花漫 我一把揭開白布迹恐。 她就那樣靜靜地躺著挣惰,像睡著了一般。 火紅的嫁衣襯著肌膚如雪殴边。 梳的紋絲不亂的頭發(fā)上憎茂,一...
    開封第一講書人閱讀 52,184評(píng)論 1 308
  • 那天,我揣著相機(jī)與錄音找都,去河邊找鬼唇辨。 笑死,一個(gè)胖子當(dāng)著我的面吹牛能耻,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播亡驰,決...
    沈念sama閱讀 40,776評(píng)論 3 421
  • 文/蒼蘭香墨 我猛地睜開眼晓猛,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來了凡辱?” 一聲冷哼從身側(cè)響起戒职,我...
    開封第一講書人閱讀 39,668評(píng)論 0 276
  • 序言:老撾萬榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎透乾,沒想到半個(gè)月后洪燥,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 46,212評(píng)論 1 319
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡乳乌,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 38,299評(píng)論 3 340
  • 正文 我和宋清朗相戀三年捧韵,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片汉操。...
    茶點(diǎn)故事閱讀 40,438評(píng)論 1 352
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡再来,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情芒篷,我是刑警寧澤搜变,帶...
    沈念sama閱讀 36,128評(píng)論 5 349
  • 正文 年R本政府宣布,位于F島的核電站针炉,受9級(jí)特大地震影響挠他,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜篡帕,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,807評(píng)論 3 333
  • 文/蒙蒙 一殖侵、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧赂苗,春花似錦愉耙、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,279評(píng)論 0 24
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至败砂,卻和暖如春赌渣,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背昌犹。 一陣腳步聲響...
    開封第一講書人閱讀 33,395評(píng)論 1 272
  • 我被黑心中介騙來泰國(guó)打工坚芜, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人斜姥。 一個(gè)月前我還...
    沈念sama閱讀 48,827評(píng)論 3 376
  • 正文 我出身青樓鸿竖,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國(guó)和親铸敏。 傳聞我的和親對(duì)象是個(gè)殘疾皇子缚忧,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,446評(píng)論 2 359

推薦閱讀更多精彩內(nèi)容