Faster RCNN Source Code Analysis (1)

Faster RCNN paper https://arxiv.org/abs/1506.01497

Faster RCNN can be viewed as a network model that combines Region Proposal + RCNN, i.e. a Region Proposal Network (RPN) plus an image classifier for object recognition.
Faster RCNN, in a sense, solves the problem of region proposal generation consuming a large amount of time.

The overall pipeline is as follows:

  1. A feature extraction backbone (a series of conv, relu, and pooling layers) produces a feature map, which is shared by the RPN and the subsequent fully connected layers.
  2. Multiple anchors (defined below) are generated at every position of the feature map.
  3. The RPN generates region proposals. This layer serves two purposes: (1) a softmax classifies each anchor as foreground or background; (2) bounding box regression refines the anchors to obtain reasonably accurate proposals.
  4. Using the shared feature map and the proposals from the RPN, per-proposal feature maps are extracted via the ROI Pooling layer.
  5. Finally, the output of ROI Pooling goes through fully connected layers to classify each proposal, and a second bounding box regression produces the final detection boxes.
Flowchart
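
As a concrete illustration of step 2, here is a minimal sketch of anchor generation — 3 scales x 3 aspect ratios = 9 anchors per feature-map position. It only approximates what the repository's generate_anchors.py produces (the centring convention there differs slightly):

import numpy as np

def anchors_at(cx, cy, base=16, scales=(8, 16, 32), ratios=(0.5, 1, 2)):
  """9 anchors (3 scales x 3 ratios) centred at image coordinate (cx, cy).

  Each anchor has area (base * scale)^2; ratio r = h / w.
  """
  anchors = []
  for s in scales:
    for r in ratios:
      w = base * s / np.sqrt(r)
      h = base * s * np.sqrt(r)
      anchors.append([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2])
  return np.array(anchors)

# the anchor set for the feature-map cell whose centre maps to pixel (120, 120)
print(anchors_at(120, 120).astype(int))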

Let us now walk through these modules from the source code's point of view.

一鸭轮、數(shù)據(jù)載入

tools/trainval_net.py
# train set
  imdb, roidb = combined_roidb(args.imdb_name)
  print('{:d} roidb entries'.format(len(roidb)))
  • imdb: a base class whose concrete instances include pascal_voc and coco. It mainly aggregates the classes, names, and paths of all images.
  • roidb: an attribute of imdb — a list of dictionaries, one per image, holding the ground-truth boxes, the true class labels, and the flipped flag.
tools/trainval_net.py
def combined_roidb(imdb_names):
  """
  Combine multiple roidbs
  """

  def get_roidb(imdb_name):
    imdb = get_imdb(imdb_name)
    print('Loaded dataset `{:s}` for training'.format(imdb.name))
    imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
    print('Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD))
    roidb = get_training_roidb(imdb)
    return roidb

  roidbs = [get_roidb(s) for s in imdb_names.split('+')]
  roidb = roidbs[0]
  if len(roidbs) > 1:
    for r in roidbs[1:]:
      roidb.extend(r)
    tmp = get_imdb(imdb_names.split('+')[1])
    imdb = datasets.imdb.imdb(imdb_names, tmp.classes)
  else:
    imdb = get_imdb(imdb_names)
  return imdb, roidb

The imdb_names argument here is specified when the script is launched, e.g. for pascal_voc, pascal_voc_0712, or coco; multiple datasets are combined by joining their names with '+'.

lib/datasets/factory.py
# Set up voc_<year>_<split> 
for year in ['2007', '2012']:
  for split in ['train', 'val', 'trainval', 'test']:
    name = 'voc_{}_{}'.format(year, split)
    __sets[name] = (lambda split=split, year=year: pascal_voc(split, year))

for year in ['2007', '2012']:
  for split in ['train', 'val', 'trainval', 'test']:
    name = 'voc_{}_{}_diff'.format(year, split)
    __sets[name] = (lambda split=split, year=year: pascal_voc(split, year, use_diff=True))

# Set up coco_2014_<split>
for year in ['2014']:
  for split in ['train', 'val', 'minival', 'valminusminival', 'trainval']:
    name = 'coco_{}_{}'.format(year, split)
    __sets[name] = (lambda split=split, year=year: coco(split, year))

# Set up coco_2015_<split>
for year in ['2015']:
  for split in ['test', 'test-dev']:
    name = 'coco_{}_{}'.format(year, split)
    __sets[name] = (lambda split=split, year=year: coco(split, year))


def get_imdb(name):
  """Get an imdb (image database) by name."""
  if name not in __sets:
    raise KeyError('Unknown dataset: {}'.format(name))
  return __sets[name]()

This is the factory module, responsible for creating the various imdb instances; we take pascal_voc as our example.
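
One detail worth noting: the lambdas above capture split and year as default arguments. Python closures bind late, so without the defaults every entry in __sets would end up reading the loop variables' final values. A minimal sketch of the pitfall:

# Late binding: every closure reads the loop variable's final value...
fns = [lambda: year for year in ['2007', '2012']]
print([f() for f in fns])        # ['2012', '2012']

# ...whereas a default argument freezes the value at definition time.
fns = [lambda year=year: year for year in ['2007', '2012']]
print([f() for f in fns])        # ['2007', '2012']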

lib/datasets/pascal_voc.py
class pascal_voc(imdb):
  def __init__(self, image_set, year, use_diff=False):
    name = 'voc_' + year + '_' + image_set
    if use_diff:
      name += '_diff'
    imdb.__init__(self, name)
    self._year = year
    self._image_set = image_set
    self._devkit_path = self._get_default_path()
    self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)
    self._classes = ('__background__',  # always index 0
                     'aeroplane', 'bicycle', 'bird', 'boat',
                     'bottle', 'bus', 'car', 'cat', 'chair',
                     'cow', 'diningtable', 'dog', 'horse',
                     'motorbike', 'person', 'pottedplant',
                     'sheep', 'sofa', 'train', 'tvmonitor')
    self._class_to_ind = dict(list(zip(self.classes, list(range(self.num_classes)))))
    self._image_ext = '.jpg'
    self._image_index = self._load_image_set_index()
    # Default to roidb handler
    self._roidb_handler = self.gt_roidb
    self._salt = str(uuid.uuid4())
    self._comp_id = 'comp4'

    # PASCAL specific config options
    self.config = {'cleanup': True,
                   'use_salt': True,
                   'use_diff': use_diff,
                   'matlab_eval': False,
                   'rpn_file': None}

    assert os.path.exists(self._devkit_path), \
      'VOCdevkit path does not exist: {}'.format(self._devkit_path)
    assert os.path.exists(self._data_path), \
      'Path does not exist: {}'.format(self._data_path)

This class has two fairly important members:

  • self._image_index: the array of indices read from the dataset's image set file (e.g. /VOCdevkit2007/VOC2007/ImageSets/Main/val.txt).
  • self._roidb_handler: the ground-truth regions of interest. If a cache file exists, the roidb is read from the cache; otherwise self._load_pascal_annotation() parses each image's annotations from its XML file.
lib/datasets/pascal_voc.py
  def _load_image_set_index(self):
    """
    Load the indexes listed in this dataset's image set file.
    """
    # Example path to image set file:
    # self._devkit_path + /VOCdevkit2007/VOC2007/ImageSets/Main/val.txt
    image_set_file = os.path.join(self._data_path, 'ImageSets', 'Main',
                                  self._image_set + '.txt')
    assert os.path.exists(image_set_file), \
      'Path does not exist: {}'.format(image_set_file)
    with open(image_set_file) as f:
      image_index = [x.strip() for x in f.readlines()]
    return image_index
lib/datasets/pascal_voc.py
  def gt_roidb(self):
    """
    Return the database of ground-truth regions of interest.

    This function loads/saves from/to a cache file to speed up future calls.
    """
    cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
    if os.path.exists(cache_file):
      with open(cache_file, 'rb') as fid:
        try:
          roidb = pickle.load(fid)
        except:
          roidb = pickle.load(fid, encoding='bytes')
      print('{} gt roidb loaded from {}'.format(self.name, cache_file))
      return roidb

    gt_roidb = [self._load_pascal_annotation(index)
                for index in self.image_index]
    with open(cache_file, 'wb') as fid:
      pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)
    print('wrote gt roidb to {}'.format(cache_file))

    return gt_roidb

For each image, the annotation loader returns a dictionary:

  • boxes: a 2-D array [num_objs, 4]; each row holds the top-left and bottom-right coordinates of a bounding box.
  • gt_classes: a 1-D array [num_objs]; each entry is the object's class index.
  • gt_overlaps: a 2-D array [num_objs, num_classes]; each row represents one object, with 1.0 in the column of that object's class.
  • flipped: whether the image has been horizontally flipped.
  • seg_areas: a 1-D array [num_objs]; the area covered by each bounding box. A sample entry is sketched right after this list.
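
For instance, for a hypothetical VOC image containing one dog (class index 12) and one person (class index 15), an equivalent entry could be built like this (the box coordinates are made up):

import numpy as np
import scipy.sparse

num_objs, num_classes = 2, 21
boxes = np.array([[ 47, 239, 194, 370],      # dog
                  [  7,  11, 351, 497]],     # person
                 dtype=np.uint16)
gt_classes = np.array([12, 15], dtype=np.int32)
overlaps = np.zeros((num_objs, num_classes), dtype=np.float32)
overlaps[np.arange(num_objs), gt_classes] = 1.0   # one-hot row per object
seg_areas = ((boxes[:, 2] - boxes[:, 0] + 1.0) *
             (boxes[:, 3] - boxes[:, 1] + 1.0)).astype(np.float32)

entry = {'boxes': boxes,
         'gt_classes': gt_classes,
         'gt_overlaps': scipy.sparse.csr_matrix(overlaps),
         'flipped': False,
         'seg_areas': seg_areas}             # [ 19536. 168015.]
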
lib/datasets/pascal_voc.py
  def _load_pascal_annotation(self, index):
    """
    Load image and bounding boxes info from XML file in the PASCAL VOC
    format.
    """
    filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
    tree = ET.parse(filename)
    objs = tree.findall('object')
    if not self.config['use_diff']:
      # Exclude the samples labeled as difficult
      non_diff_objs = [
        obj for obj in objs if int(obj.find('difficult').text) == 0]
      # if len(non_diff_objs) != len(objs):
      #     print 'Removed {} difficult objects'.format(
      #         len(objs) - len(non_diff_objs))
      objs = non_diff_objs
    num_objs = len(objs)

    boxes = np.zeros((num_objs, 4), dtype=np.uint16)
    gt_classes = np.zeros((num_objs), dtype=np.int32)
    overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
    # "Seg" area for pascal is just the box area
    seg_areas = np.zeros((num_objs), dtype=np.float32)

    # Load object bounding boxes into a data frame.
    for ix, obj in enumerate(objs):
      bbox = obj.find('bndbox')
      # Make pixel indexes 0-based
      x1 = float(bbox.find('xmin').text) - 1
      y1 = float(bbox.find('ymin').text) - 1
      x2 = float(bbox.find('xmax').text) - 1
      y2 = float(bbox.find('ymax').text) - 1
      cls = self._class_to_ind[obj.find('name').text.lower().strip()]
      boxes[ix, :] = [x1, y1, x2, y2]
      gt_classes[ix] = cls
      overlaps[ix, cls] = 1.0
      seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)

    overlaps = scipy.sparse.csr_matrix(overlaps)

    return {'boxes': boxes,
            'gt_classes': gt_classes,
            'gt_overlaps': overlaps,
            'flipped': False,
            'seg_areas': seg_areas}

There is one more important step inside def get_roidb(imdb_name): roidb = get_training_roidb(imdb).
The following function returns a roidb ready for training:

lib/model/train_val.py
def get_training_roidb(imdb):
  """Returns a roidb (Region of Interest database) for use in training."""
  if cfg.TRAIN.USE_FLIPPED:
    print('Appending horizontally-flipped training examples...')
    imdb.append_flipped_images()
    print('done')

  print('Preparing training data...')
  rdl_roidb.prepare_roidb(imdb)
  print('done')

  return imdb.roidb

The function below horizontally flips the boxes of every image and appends the flipped entries to the original roidb.

lib/datasets/imdb.py
  def append_flipped_images(self):
    num_images = self.num_images
    widths = self._get_widths()
    for i in range(num_images):
      boxes = self.roidb[i]['boxes'].copy()
      oldx1 = boxes[:, 0].copy()
      oldx2 = boxes[:, 2].copy()
      boxes[:, 0] = widths[i] - oldx2 - 1
      boxes[:, 2] = widths[i] - oldx1 - 1
      assert (boxes[:, 2] >= boxes[:, 0]).all()
      entry = {'boxes': boxes,
               'gt_overlaps': self.roidb[i]['gt_overlaps'],
               'gt_classes': self.roidb[i]['gt_classes'],
               'flipped': True}
      self.roidb.append(entry)
    self._image_index = self._image_index * 2
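
Under the 0-based, inclusive pixel convention, an x-span [x1, x2] in an image of width W maps to [W - x2 - 1, W - x1 - 1] after a horizontal flip. A quick numeric check with made-up values:

import numpy as np

W = 500                                  # hypothetical image width
boxes = np.array([[10, 50, 30, 80]])     # x1, y1, x2, y2
flipped = boxes.copy()
flipped[:, 0] = W - boxes[:, 2] - 1      # new x1 = 500 - 30 - 1 = 469
flipped[:, 2] = W - boxes[:, 0] - 1      # new x2 = 500 - 10 - 1 = 489
assert (flipped[:, 2] >= flipped[:, 0]).all()
print(flipped)                           # [[469  50 489  80]]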

prepare_roidb then turns the roidb generated so far into one usable for training by adding some metadata:

  • max_classes: the class of each object; since we are using ground truth here, this is simply the object's own class.
  • max_overlaps: the maximum overlap between each box and any ground-truth box; a box overlaps itself completely, so the value here is 1.0. The RPN will make use of it later (a small numerical illustration follows this list).
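
A minimal numerical illustration of what prepare_roidb (below) derives from gt_overlaps:

import numpy as np

gt_overlaps = np.array([[0., 0., 1., 0.],    # an object of class 2
                        [0., 1., 0., 0.]])   # an object of class 1
print(gt_overlaps.max(axis=1))     # max_overlaps -> [1. 1.]
print(gt_overlaps.argmax(axis=1))  # max_classes  -> [2 1]
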
lib/roi_data_layer/roidb.py
def prepare_roidb(imdb):
  """Enrich the imdb's roidb by adding some derived quantities that
  are useful for training. This function precomputes the maximum
  overlap, taken over ground-truth boxes, between each ROI and
  each ground-truth box. The class with maximum overlap is also
  recorded.
  """
  roidb = imdb.roidb
  if not (imdb.name.startswith('coco')):
    sizes = [PIL.Image.open(imdb.image_path_at(i)).size
         for i in range(imdb.num_images)]
  for i in range(len(imdb.image_index)):
    roidb[i]['image'] = imdb.image_path_at(i)
    if not (imdb.name.startswith('coco')):
      roidb[i]['width'] = sizes[i][0]
      roidb[i]['height'] = sizes[i][1]
    # need gt_overlaps as a dense array for argmax
    gt_overlaps = roidb[i]['gt_overlaps'].toarray()
    # max overlap with gt over classes (columns)
    max_overlaps = gt_overlaps.max(axis=1)
    # gt class that had the max overlap
    max_classes = gt_overlaps.argmax(axis=1)
    roidb[i]['max_classes'] = max_classes
    roidb[i]['max_overlaps'] = max_overlaps
    # sanity checks
    # max overlap of 0 => class should be zero (background)
    zero_inds = np.where(max_overlaps == 0)[0]
    assert all(max_classes[zero_inds] == 0)
    # max overlap > 0 => class should not be zero (must be a fg class)
    nonzero_inds = np.where(max_overlaps > 0)[0]
    assert all(max_classes[nonzero_inds] != 0)

Part 2: Feature Extraction

For residual networks I recommend reading this paper: https://arxiv.org/abs/1512.03385
This article does not dig deep into them; we only briefly describe the network from the implementation side.

In computer vision it is widely believed that deeper networks extract higher-level features. However, vanishing/exploding gradients prevent very deep networks from converging. Although some techniques compensate — Batch Normalization, suitable activation functions, gradient clipping — and raise the trainable depth roughly tenfold, network performance nonetheless begins to degrade beyond a point, actually producing larger errors.


  • Vanishing gradients: largely caused by the "saturation" of activation functions. Back propagation needs the derivative of the activation; once a convolution kernel's output falls into the saturated region, its gradient becomes very small. As gradients propagate backward through more layers their magnitude shrinks drastically, so the weights of the shallow layers update extremely slowly and learning stalls.
  • Exploding gradients: the opposite situation — gradients grow sharply during propagation, in extreme cases making the weights so large that they overflow, leaving the network very unstable.
  • Batch Normalization: arguably one of the most important results since the rise of deep learning; I still recommend studying it in detail on your own. In short, it normalizes each layer's output to a consistent mean and variance, cancelling the amplification or attenuation introduced by the weights. This not only addresses vanishing and exploding gradients but also speeds up convergence; you can think of BN as pulling the outputs from the saturated region back into the non-saturated one.
  • Gradient clipping: aimed mainly at exploding gradients. The idea is to set a clipping threshold and, when updating, force any gradient that exceeds the threshold back into range, preventing the explosion (a TensorFlow sketch follows this list).
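
A hedged TensorFlow 1.x sketch of clipping by global norm (this repository's own training loop does not clip gradients, as far as I can tell; the toy loss is purely illustrative):

import tensorflow as tf

x = tf.Variable([3.0, 4.0])
loss = tf.reduce_sum(tf.square(x))                          # toy scalar loss

opt = tf.train.GradientDescentOptimizer(0.1)
grads, tvars = zip(*opt.compute_gradients(loss))
clipped, _ = tf.clip_by_global_norm(grads, clip_norm=5.0)   # rescale if the global norm exceeds 5
train_op = opt.apply_gradients(list(zip(clipped, tvars)))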

ResNet lets the depth grow essentially without limit while keeping the network's performance intact; its basic building block looks like this:
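
The key idea is the shortcut connection: a block learns a residual F(x) and outputs F(x) + x, so gradients can flow back through the identity path unimpeded. A simplified bottleneck-unit sketch with tf.slim (unlike the real resnet_v1 unit it ignores striding and the projection shortcut used when channel counts differ):

import tensorflow as tf
import tensorflow.contrib.slim as slim

def bottleneck(x, base_depth):
  """1x1 reduce -> 3x3 -> 1x1 expand, then add the identity shortcut.
  Assumes x already has base_depth * 4 channels so the shapes match."""
  residual = slim.conv2d(x, base_depth, [1, 1])
  residual = slim.conv2d(residual, base_depth, [3, 3])
  residual = slim.conv2d(residual, base_depth * 4, [1, 1], activation_fn=None)
  return tf.nn.relu(x + residual)          # output = F(x) + x

inputs = tf.placeholder(tf.float32, [1, 56, 56, 256])
out = bottleneck(inputs, base_depth=64)    # shape preserved: [1, 56, 56, 256]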



As the feature extraction network I chose a deep residual network (ResNet) with 101 layers.

tools/trainval_net.py
# load network
  if args.net == 'vgg16':
    net = vgg16()
  elif args.net == 'res50':
    net = resnetv1(num_layers=50)
  elif args.net == 'res101':
    net = resnetv1(num_layers=101)
  elif args.net == 'res152':
    net = resnetv1(num_layers=152)
  elif args.net == 'mobile':
    net = mobilenetv1()
  else:
    raise NotImplementedError

resnetv1 is a subclass of Network; below is its initialization code.

  • Network is the base class of the feature extraction networks. It provides members that track anchors, losses, layers, images, training, and regression information, plus common building blocks — fully connected layer, max pooling layer, dropout layer, activation functions and loss functions — as well as functions specific to Faster RCNN such as region proposal, region classification, and the proposal layer. I will let you explore that source yourself.
    The ResNet-101 setup is as follows (I plan to publish an article in November dedicated to the ResNet implementation; the faster-rcnn-tf source directly uses tensorflow.contrib.slim.python.slim.nets.resnet_v1.resnet_v1_block):
lib/nets/resnet_v1.py
class resnetv1(Network):
  def __init__(self, num_layers=50):
    Network.__init__(self)
    self._feat_stride = [16, ]
    self._feat_compress = [1. / float(self._feat_stride[0]), ]
    self._num_layers = num_layers
    self._scope = 'resnet_v1_%d' % num_layers
    self._decide_blocks()


  def _decide_blocks(self):
    # choose different blocks for different number of layers
    if self._num_layers == 50:
      self._blocks = [resnet_v1_block('block1', base_depth=64, num_units=3, stride=2),
                      resnet_v1_block('block2', base_depth=128, num_units=4, stride=2),
                      # use stride 1 for the last conv4 layer
                      resnet_v1_block('block3', base_depth=256, num_units=6, stride=1),
                      resnet_v1_block('block4', base_depth=512, num_units=3, stride=1)]

    elif self._num_layers == 101:
      self._blocks = [resnet_v1_block('block1', base_depth=64, num_units=3, stride=2),
                      resnet_v1_block('block2', base_depth=128, num_units=4, stride=2),
                      # use stride 1 for the last conv4 layer
                      resnet_v1_block('block3', base_depth=256, num_units=23, stride=1),
                      resnet_v1_block('block4', base_depth=512, num_units=3, stride=1)]

    elif self._num_layers == 152:
      self._blocks = [resnet_v1_block('block1', base_depth=64, num_units=3, stride=2),
                      resnet_v1_block('block2', base_depth=128, num_units=8, stride=2),
                      # use stride 1 for the last conv4 layer
                      resnet_v1_block('block3', base_depth=256, num_units=36, stride=1),
                      resnet_v1_block('block4', base_depth=512, num_units=3, stride=1)]

    else:
      # other numbers are not supported
      raise NotImplementedError
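
The _feat_stride of 16 set in __init__ records the total downsampling between the input image and the feature map handed to the RPN: the initial 7x7/stride-2 convolution, a stride-2 max pool, and the stride-2 block1 and block2 give 2^4 = 16, while block3 and block4 keep stride 1, as the comments note. A quick sanity check of the resulting feature-map size:

import math

feat_stride = 16
im_h, im_w = 600, 1000                              # a typical rescaled input
fm_h = int(math.ceil(im_h / float(feat_stride)))    # 38
fm_w = int(math.ceil(im_w / float(feat_stride)))    # 63
print(fm_h, fm_w, fm_h * fm_w * 9)                  # 38 63 21546 anchors in total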

Part 3: Training

Once the training data and the network are ready, we can start training.

tools/trainval_net.py
train_net(net, imdb, roidb, valroidb, output_dir, tb_dir,
            pretrained_model=args.weight,
            max_iters=args.max_iters)

train_net first filters out roidb entries that do not meet the requirements, then builds a wrapper class, SolverWrapper, to drive the training process.

lib/model/train_val.py
def train_net(network, imdb, roidb, valroidb, output_dir, tb_dir,
              pretrained_model=None,
              max_iters=40000):
  """Train a Faster R-CNN network."""
  roidb = filter_roidb(roidb)
  valroidb = filter_roidb(valroidb)

  tfconfig = tf.ConfigProto(allow_soft_placement=True)
  tfconfig.gpu_options.allow_growth = True

  with tf.Session(config=tfconfig) as sess:
    sw = SolverWrapper(sess, network, imdb, roidb, valroidb, output_dir, tb_dir,
                       pretrained_model=pretrained_model)
    print('Solving...')
    sw.train_model(sess, max_iters)
    print('done solving')


def filter_roidb(roidb):
  """Remove roidb entries that have no usable RoIs."""

  def is_valid(entry):
    # Valid images have:
    #   (1) At least one foreground RoI OR
    #   (2) At least one background RoI
    overlaps = entry['max_overlaps']
    # find boxes with sufficient overlap
    fg_inds = np.where(overlaps >= cfg.TRAIN.FG_THRESH)[0]
    # Select background RoIs as those within [BG_THRESH_LO, BG_THRESH_HI)
    bg_inds = np.where((overlaps < cfg.TRAIN.BG_THRESH_HI) &
                       (overlaps >= cfg.TRAIN.BG_THRESH_LO))[0]
    # image is only valid if such boxes exist
    valid = len(fg_inds) > 0 or len(bg_inds) > 0
    return valid

  num = len(roidb)
  filtered_roidb = [entry for entry in roidb if is_valid(entry)]
  num_after = len(filtered_roidb)
  print('Filtered {} roidb entries: {} -> {}'.format(num - num_after,
                                                     num, num_after))
  return filtered_roidb
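
For illustration, with typical threshold values such as FG_THRESH = 0.5, BG_THRESH_HI = 0.5 and BG_THRESH_LO = 0.1 (check your config.py — the exact defaults vary between forks), the same logic behaves like this; note that for a pure ground-truth roidb every max_overlaps value is 1.0, so in practice little gets filtered:

import numpy as np

FG_THRESH, BG_THRESH_HI, BG_THRESH_LO = 0.5, 0.5, 0.1   # illustrative values

def is_valid(max_overlaps):
  fg = np.where(max_overlaps >= FG_THRESH)[0]
  bg = np.where((max_overlaps < BG_THRESH_HI) &
                (max_overlaps >= BG_THRESH_LO))[0]
  return len(fg) > 0 or len(bg) > 0

print(is_valid(np.array([1.0, 0.3])))   # True: one fg and one bg box
print(is_valid(np.array([0.05])))       # False: neither fg nor bg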

Next we train the model.

lib/model/train_val.py

  def train_model(self, sess, max_iters):
    # Build data layers for both training and validation set
    self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
    self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True)

    # Construct the computation graph
    lr, train_op = self.construct_graph(sess)

    # Find previous snapshots if there is any to restore from
    lsf, nfiles, sfiles = self.find_previous()

    # Initialize the variables or restore them from the last snapshot
    if lsf == 0:
      rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(sess)
    else:
      rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(sess, 
                                                                            str(sfiles[-1]), 
                                                                            str(nfiles[-1]))
    timer = Timer()
    iter = last_snapshot_iter + 1
    last_summary_time = time.time()
    # Make sure the lists are not empty
    stepsizes.append(max_iters)
    stepsizes.reverse()
    next_stepsize = stepsizes.pop()
    while iter < max_iters + 1:
      # Learning rate
      if iter == next_stepsize + 1:
        # Add snapshot here before reducing the learning rate
        self.snapshot(sess, iter)
        rate *= cfg.TRAIN.GAMMA
        sess.run(tf.assign(lr, rate))
        next_stepsize = stepsizes.pop()

      timer.tic()
      # Get training data, one batch at a time
      blobs = self.data_layer.forward()

      now = time.time()
      if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
        # Compute the graph with summary
        rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
          self.net.train_step_with_summary(sess, blobs, train_op)
        self.writer.add_summary(summary, float(iter))
        # Also check the summary on the validation set
        blobs_val = self.data_layer_val.forward()
        summary_val = self.net.get_summary(sess, blobs_val)
        self.valwriter.add_summary(summary_val, float(iter))
        last_summary_time = now
      else:
        # Compute the graph without summary
        rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
          self.net.train_step(sess, blobs, train_op)
      timer.toc()

      # Display training information
      if iter % (cfg.TRAIN.DISPLAY) == 0:
        print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
              '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' % \
              (iter, max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr.eval()))
        print('speed: {:.3f}s / iter'.format(timer.average_time))

      # Snapshotting
      if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
        last_snapshot_iter = iter
        ss_path, np_path = self.snapshot(sess, iter)
        np_paths.append(np_path)
        ss_paths.append(ss_path)

        # Remove the old snapshots if there are too many
        if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
          self.remove_snapshot(np_paths, ss_paths)

      iter += 1

    if last_snapshot_iter != iter - 1:
      self.snapshot(sess, iter - 1)

    self.writer.close()
    self.valwriter.close()

Breaking this code down, we can identify four stages:

  1. Build the data input layers for training and validation
  2. Construct the computation graph
  3. Load the training parameters
  4. Run the regular training loop (update the parameters, update the learning rate, fetch the next training batch, ...)

Stage 1

The RoIDataLayer class is the layer that feeds ground-truth RoIs; at initialization it shuffles the dataset order (the random flag is mainly used for the validation set).

# Build data layers for both training and validation set
    self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
    self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True)
lib/roi_data_layer/layer.py
class RoIDataLayer(object):
  """Fast R-CNN data layer used for training."""

  def __init__(self, roidb, num_classes, random=False):
    """Set the roidb to be used by this layer during training."""
    self._roidb = roidb
    self._num_classes = num_classes
    # Also set a random flag
    self._random = random
    self._shuffle_roidb_inds()

  def _shuffle_roidb_inds(self):
    """Randomly permute the training roidb."""
    # If the random flag is set, 
    # then the database is shuffled according to system time
    # Useful for the validation set
    if self._random:
      st0 = np.random.get_state()
      millis = int(round(time.time() * 1000)) % 4294967295
      np.random.seed(millis)
    
    if cfg.TRAIN.ASPECT_GROUPING:
      widths = np.array([r['width'] for r in self._roidb])
      heights = np.array([r['height'] for r in self._roidb])
      horz = (widths >= heights)
      vert = np.logical_not(horz)
      horz_inds = np.where(horz)[0]
      vert_inds = np.where(vert)[0]
      inds = np.hstack((
          np.random.permutation(horz_inds),
          np.random.permutation(vert_inds)))
      inds = np.reshape(inds, (-1, 2))
      row_perm = np.random.permutation(np.arange(inds.shape[0]))
      inds = np.reshape(inds[row_perm, :], (-1,))
      self._perm = inds
    else:
      self._perm = np.random.permutation(np.arange(len(self._roidb)))
    # Restore the random state
    if self._random:
      np.random.set_state(st0)
      
    self._cur = 0
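
With cfg.TRAIN.ASPECT_GROUPING enabled, the permutation above is built so that consecutive index pairs share an orientation (landscape with landscape, portrait with portrait) — useful when a batch holds two images, less important here where the batch size is 1. A toy trace of that reshuffling, assuming four landscape and two portrait images:

import numpy as np

horz_inds = np.array([0, 1, 2, 3])              # landscape images
vert_inds = np.array([4, 5])                    # portrait images
inds = np.hstack((np.random.permutation(horz_inds),
                  np.random.permutation(vert_inds)))
inds = inds.reshape(-1, 2)                      # pairs never mix orientations
row_perm = np.random.permutation(np.arange(inds.shape[0]))
print(inds[row_perm, :].reshape(-1))            # e.g. [4 5 2 0 1 3]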

The most important method of RoIDataLayer is forward, which fetches the next minibatch. forward calls _get_next_minibatch, which pulls the bounding-box information of each roidb entry in the minibatch into a dictionary. The detailed code is as follows:

lib/roi_data_layer/layer.py
  def _get_next_minibatch_inds(self):
    """Return the roidb indices for the next minibatch."""
    
    if self._cur + cfg.TRAIN.IMS_PER_BATCH >= len(self._roidb):
      self._shuffle_roidb_inds()

    db_inds = self._perm[self._cur:self._cur + cfg.TRAIN.IMS_PER_BATCH]
    self._cur += cfg.TRAIN.IMS_PER_BATCH

    return db_inds

  def _get_next_minibatch(self):
    """Return the blobs to be used for the next minibatch.

    If cfg.TRAIN.USE_PREFETCH is True, then blobs will be computed in a
    separate process and made available through self._blob_queue.
    """
    db_inds = self._get_next_minibatch_inds()
    minibatch_db = [self._roidb[i] for i in db_inds]
    return get_minibatch(minibatch_db, self._num_classes)
      
  def forward(self):
    """Get blobs and copy them into this layer's top blob vector."""
    blobs = self._get_next_minibatch()
    return blobs

_get_next_minibatch calls _get_next_minibatch_inds to obtain the indices of the next minibatch and advance the cursor; when the cursor would run past the end of the dataset, the order is reshuffled and we start over from the beginning — the same flow as an epoch of stochastic gradient descent, except that the minibatch here holds a single image (cfg.TRAIN.IMS_PER_BATCH = 1), i.e. we train on one image at a time. Finally, get_minibatch builds the minibatch; the returned data is a dictionary:

  • data: shape [num_img, h, w, 3]; the image input of the network, with num_img = 1 here
  • gt_boxes: bounding boxes, shape [num_gt_indexs, 5]; the first 4 columns are the box coordinates, the 5th column is the object class
  • im_info: image info with three elements; the first two are the height and width of the final input image, the third is the scale factor from the original image to the current one
lib/roi_data_layer/minibatch.py
def get_minibatch(roidb, num_classes):
  """Given a roidb, construct a minibatch sampled from it."""
  num_images = len(roidb)
  # Sample random scales to use for each image in this batch
  random_scale_inds = npr.randint(0, high=len(cfg.TRAIN.SCALES),
                  size=num_images)
  assert(cfg.TRAIN.BATCH_SIZE % num_images == 0), \
    'num_images ({}) must divide BATCH_SIZE ({})'. \
    format(num_images, cfg.TRAIN.BATCH_SIZE)

  # Get the input image blob, formatted for caffe
  im_blob, im_scales = _get_image_blob(roidb, random_scale_inds)

  blobs = {'data': im_blob}

  assert len(im_scales) == 1, "Single batch only"
  assert len(roidb) == 1, "Single batch only"
  
  # gt boxes: (x1, y1, x2, y2, cls)
  if cfg.TRAIN.USE_ALL_GT:
    # Include all ground truth boxes
    gt_inds = np.where(roidb[0]['gt_classes'] != 0)[0]
  else:
    # For the COCO ground truth boxes, exclude the ones that are ''iscrowd'' 
    gt_inds = np.where(roidb[0]['gt_classes'] != 0 & np.all(roidb[0]['gt_overlaps'].toarray() > -1.0, axis=1))[0]
  gt_boxes = np.empty((len(gt_inds), 5), dtype=np.float32)
  gt_boxes[:, 0:4] = roidb[0]['boxes'][gt_inds, :] * im_scales[0]
  gt_boxes[:, 4] = roidb[0]['gt_classes'][gt_inds]
  blobs['gt_boxes'] = gt_boxes
  blobs['im_info'] = np.array(
    [im_blob.shape[1], im_blob.shape[2], im_scales[0]],
    dtype=np.float32)

  return blobs

def _get_image_blob(roidb, scale_inds):
  """Builds an input blob from the images in the roidb at the specified
  scales.
  """
  num_images = len(roidb)
  processed_ims = []
  im_scales = []
  for i in range(num_images):
    im = cv2.imread(roidb[i]['image'])
    if roidb[i]['flipped']:
      im = im[:, ::-1, :]
    target_size = cfg.TRAIN.SCALES[scale_inds[i]]
    im, im_scale = prep_im_for_blob(im, cfg.PIXEL_MEANS, target_size,
                    cfg.TRAIN.MAX_SIZE)
    im_scales.append(im_scale)
    processed_ims.append(im)

  # Create a blob to hold the input images
  blob = im_list_to_blob(processed_ims)

  return blob, im_scales

  • im_list_to_blob: converts a list of images into the network's input blob
  • prep_im_for_blob: mean-subtracts and rescales an image for use in a blob
lib/utils/blob.py
def im_list_to_blob(ims):
  """Convert a list of images into a network input.

  Assumes images are already prepared (means subtracted, BGR order, ...).
  """
  max_shape = np.array([im.shape for im in ims]).max(axis=0)
  num_images = len(ims)
  blob = np.zeros((num_images, max_shape[0], max_shape[1], 3),
                  dtype=np.float32)
  for i in range(num_images):
    im = ims[i]
    blob[i, 0:im.shape[0], 0:im.shape[1], :] = im

  return blob


def prep_im_for_blob(im, pixel_means, target_size, max_size):
  """Mean subtract and scale an image for use in a blob."""
  im = im.astype(np.float32, copy=False)
  im -= pixel_means
  im_shape = im.shape
  im_size_min = np.min(im_shape[0:2])
  im_size_max = np.max(im_shape[0:2])
  im_scale = float(target_size) / float(im_size_min)
  # Prevent the biggest axis from being more than MAX_SIZE
  if np.round(im_scale * im_size_max) > max_size:
    im_scale = float(max_size) / float(im_size_max)
  im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale,
                  interpolation=cv2.INTER_LINEAR)

  return im, im_scale
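
For example, with the usual cfg.TRAIN.SCALES = (600,) and cfg.TRAIN.MAX_SIZE = 1000, a 375x500 image is scaled by 600/375 = 1.6 to 600x800, while a 400x1000 image would first get 1.5, overshoot MAX_SIZE (1500 > 1000), and be clamped back to 1.0:

import numpy as np

def compute_scale(h, w, target_size=600, max_size=1000):
  # same rule as prep_im_for_blob, applied to the image shape only
  im_scale = float(target_size) / min(h, w)
  if np.round(im_scale * max(h, w)) > max_size:
    im_scale = float(max_size) / max(h, w)
  return im_scale

print(compute_scale(375, 500))    # 1.6  -> resized to 600 x 800
print(compute_scale(400, 1000))   # 1.0  -> kept at 400 x 1000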

Stage 2

Stage 2 is the core of the whole Faster RCNN, covering the RPN, ROI pooling, and the final object classification and bounding box regression. I plan to cover it in a separate article, so we skip it here and continue with the overall training flow.
Faster RCNN Source Code Analysis (2).

Stage 3

Parameter loading mainly involves two functions: self.initialize(sess) and self.restore(sess, str(sfiles[-1]), str(nfiles[-1]))

  • initialize: loads weights from a pretrained model. It first calls get_variables_in_checkpoint_file to read the variables stored in the checkpoint file, then calls get_variables_to_restore (overridden in Network subclasses such as resnetv1) to filter them, skipping the fixed first-layer variables; next it restores the surviving variables with tf.train.Saver, and finally calls fix_variables to patch the fixed variables (e.g. converting the first convolution's RGB weights to BGR).
  • restore: loads the weights and the learning rate from a snapshot. It restores the most recent snapshot through self.saver, and for every step size the snapshot has already passed it decays the learning rate once: rate *= cfg.TRAIN.GAMMA (0.1); see the trace sketched below.
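
Here is that trace with illustrative numbers (LEARNING_RATE = 0.001, GAMMA = 0.1, one hypothetical decay point at iteration 50000, resuming from a snapshot taken at iteration 60000):

LEARNING_RATE, GAMMA = 0.001, 0.1   # illustrative config values
STEPSIZE = [50000]                  # hypothetical decay point
last_snapshot_iter = 60000          # snapshot we resume from

rate, stepsizes = LEARNING_RATE, []
for stepsize in STEPSIZE:
  if last_snapshot_iter > stepsize:
    rate *= GAMMA                   # this decay already happened
  else:
    stepsizes.append(stepsize)      # still ahead of us
print(rate, stepsizes)              # 0.0001 []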
lib/model/train_val.py
  def initialize(self, sess):
    # Initial file lists are empty
    np_paths = []
    ss_paths = []
    # Fresh train directly from ImageNet weights
    print('Loading initial model weights from {:s}'.format(self.pretrained_model))
    variables = tf.global_variables()
    # Initialize all variables first
    sess.run(tf.variables_initializer(variables, name='init'))
    var_keep_dic = self.get_variables_in_checkpoint_file(self.pretrained_model)
    # Get the variables to restore, ignoring the variables to fix
    variables_to_restore = self.net.get_variables_to_restore(variables, var_keep_dic)

    restorer = tf.train.Saver(variables_to_restore)
    restorer.restore(sess, self.pretrained_model)
    print('Loaded.')
    # Need to fix the variables before loading, so that the RGB weights are changed to BGR
    # For VGG16 it also changes the convolutional weights fc6 and fc7 to
    # fully connected weights
    self.net.fix_variables(sess, self.pretrained_model)
    print('Fixed.')
    last_snapshot_iter = 0
    rate = cfg.TRAIN.LEARNING_RATE
    stepsizes = list(cfg.TRAIN.STEPSIZE)

    return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

  def restore(self, sess, sfile, nfile):
    # Get the most recent snapshot and restore
    np_paths = [nfile]
    ss_paths = [sfile]
    # Restore model from snapshots
    last_snapshot_iter = self.from_snapshot(sess, sfile, nfile)
    # Set the learning rate
    rate = cfg.TRAIN.LEARNING_RATE
    stepsizes = []
    for stepsize in cfg.TRAIN.STEPSIZE:
      if last_snapshot_iter > stepsize:
        rate *= cfg.TRAIN.GAMMA
      else:
        stepsizes.append(stepsize)

    return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

  def get_variables_in_checkpoint_file(self, file_name):
    try:
      reader = pywrap_tensorflow.NewCheckpointReader(file_name)
      var_to_shape_map = reader.get_variable_to_shape_map()
      return var_to_shape_map 
    except Exception as e:  # pylint: disable=broad-except
      print(str(e))
      if "corrupted compressed block contents" in str(e):
        print("It's likely that your checkpoint file has been compressed "
              "with SNAPPY.")

 ########## lib/nets/resnet_v1.py ###################################
  def get_variables_to_restore(self, variables, var_keep_dic):
    variables_to_restore = []

    for v in variables:
      # exclude the first conv layer to swap RGB to BGR
      if v.name == (self._scope + '/conv1/weights:0'):
        self._variables_to_fix[v.name] = v
        continue
      if v.name.split(':')[0] in var_keep_dic:
        print('Variables restored: %s' % v.name)
        variables_to_restore.append(v)

    return variables_to_restore
######################################################################

  def from_snapshot(self, sess, sfile, nfile):
    print('Restoring model snapshots from {:s}'.format(sfile))
    self.saver.restore(sess, sfile)
    print('Restored.')
    # Needs to restore the other hyper-parameters/states for training, (TODO xinlei) I have
    # tried my best to find the random states so that it can be recovered exactly
    # However the Tensorflow state is currently not available
    with open(nfile, 'rb') as fid:
      st0 = pickle.load(fid)
      cur = pickle.load(fid)
      perm = pickle.load(fid)
      cur_val = pickle.load(fid)
      perm_val = pickle.load(fid)
      last_snapshot_iter = pickle.load(fid)

      np.random.set_state(st0)
      self.data_layer._cur = cur
      self.data_layer._perm = perm
      self.data_layer_val._cur = cur_val
      self.data_layer_val._perm = perm_val

    return last_snapshot_iter

Stage 4

Iterate the training loop, save checkpoints, and print intermediate output; when training finishes, save the final snapshot.
The code is straightforward: it calls TensorFlow's session.run to execute the previously constructed network, i.e. the operations of the computation graph.

lib/nets/network.py
  def train_step_with_summary(self, sess, blobs, train_op):
    feed_dict = {self._image: blobs['data'], self._im_info: blobs['im_info'],
                 self._gt_boxes: blobs['gt_boxes']}
    rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, loss, summary, _ = sess.run([self._losses["rpn_cross_entropy"],
                                                                                 self._losses['rpn_loss_box'],
                                                                                 self._losses['cross_entropy'],
                                                                                 self._losses['loss_box'],
                                                                                 self._losses['total_loss'],
                                                                                 self._summary_op,
                                                                                 train_op],
                                                                                feed_dict=feed_dict)
    return rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, loss, summary

  def train_step(self, sess, blobs, train_op):
    feed_dict = {self._image: blobs['data'], self._im_info: blobs['im_info'],
                 self._gt_boxes: blobs['gt_boxes']}
    rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, loss, _ = sess.run([self._losses["rpn_cross_entropy"],
                                                                        self._losses['rpn_loss_box'],
                                                                        self._losses['cross_entropy'],
                                                                        self._losses['loss_box'],
                                                                        self._losses['total_loss'],
                                                                        train_op],
                                                                       feed_dict=feed_dict)
    return rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, loss

  def get_summary(self, sess, blobs):
    feed_dict = {self._image: blobs['data'], self._im_info: blobs['im_info'],
                 self._gt_boxes: blobs['gt_boxes']}
    summary = sess.run(self._summary_op_val, feed_dict=feed_dict)

    return summary

That concludes the whole training process. The validation and test code is not covered here because, apart from a few minor details, it follows the same flow as ordinary neural network training.
Faster RCNN Source Code Analysis (2).
