Faster R-CNN tensorflow代碼詳解

研究背景

根據(jù)Faster-RCNN算法的運行和調(diào)試情況误澳,對代碼進行深入分析偿乖。

參考資料

Faster R-CNN:tf-faster-rcnn代碼結(jié)構(gòu)
分析參考1
分析參考2
Faster RCNN整體流程
Faster RCNN算法詳解
Faster-Rcnn demo.py解析
Faster R-CNN的訓(xùn)練過程的理解

各部分代碼分析

1 編譯Cython模塊

cd tf-faster-rcnn/lib  # 首先進入目錄Faster-RCNN_TF/lib
make clean
make  #編譯

編譯成功之后,目錄tf-faster-rcnn/lib/nms 和tf-faster-rcnn/lib/roi_pooling_layer/ 和tf-faster-rcnn/lib/utils下面會出現(xiàn)一些.so文件堕战。
注意:.so文件不具可移植到性方淤,因為編譯生成的文件是只適應(yīng)本臺計算機的媚创,換一臺計算機之后竞膳,用原來的.so文件程序會出錯航瞭。并且,必須要先刪除舊的.so文件make clean坦辟,否則就會調(diào)用舊的.so文件刊侯,而不生成新的.so文件。重新運行程序的時候锉走,要先刪除這幾個.so文件滨彻,并重新進行編譯。

2 pascal_voc數(shù)據(jù)集的數(shù)據(jù)讀寫接口

2.1 工程文件tf-faster-rcnn中讀取數(shù)據(jù)的接口都在目錄tf-faster-rcnn/lib/datasets下挪蹭。共有2種數(shù)據(jù)來訓(xùn)練網(wǎng)絡(luò)亭饵,分別是pascal_voc和coco,數(shù)據(jù)讀寫接口分別是tf-faster-rcnn/lib/datasets中的pascal_voc.py和coco.py梁厉。

工程主要用到的是目錄Annotations中的XML文件辜羊、目錄JPEGImages中的圖片、目錄ImageSets/Layout中的txt文件懂算。
目錄下其他文件:
factory.py:是個工廠類只冻,用類生成imdb類并且返回數(shù)據(jù)庫供網(wǎng)絡(luò)訓(xùn)練和測試使用庇麦;
imdb.py:是數(shù)據(jù)庫讀寫類的基類计技,分裝了許多db的操作,具體的一些文件讀寫需要繼承繼續(xù)讀寫山橄。

VOCdevkit/
VOCdevkit/VOC2007/
VOCdevkit/VOC2007/Annotations #所有圖片的XML文件垮媒,一張圖片對應(yīng)一個XML文件,XML文件中給出的圖片gt的形式是左上角和右下角的坐標(biāo)
VOCdevkit/VOC2007/ImageSets/
VOCdevkit/VOC2007/ImageSets/Layout #里面有三個txt文件航棱,分別是train.txt,trainval.txt,val.txt,存儲的分別是訓(xùn)練圖片的名字列表睡雇,訓(xùn)練驗證集的圖片名字列表,驗證集圖片的名字列表(名字均沒有.jpg后綴)
VOCdevkit/VOC2007/ImageSets/Main
VOCdevkit/VOC2007/ImageSets/Segmentation
VOCdevkit/VOC2007/JPEGImages #所有的圖片*.jpg
VOCdevkit/VOC2007/SegmentationClass #segmentations by class
VOCdevkit/VOC2007/SegmentationObject #segmentations by object

2.2 pascal_voc的數(shù)據(jù)讀寫接口

  • 主函數(shù) if name == ‘main’在文件pascal_voc.py的最下面
if __name__ == '__main__':
    from datasets.pascal_voc import pascal_voc
    d = pascal_voc('trainval', '2007')   #pascal_voc是一個類
    res = d.roidb
    from IPython import embed; 
    embed()
  • 主函數(shù)中的類 pascal_voc代碼饮醇,在文件pascal_voc.py的最上面:
class pascal_voc(imdb):
  def __init__(self, image_set, year, use_diff=False):
    name = 'voc_' + year + '_' + image_set
    if use_diff:
      name += '_diff'
    imdb.__init__(self, name)
    self._year = year
    self._image_set = image_set
    self._devkit_path = self._get_default_path()
    self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)
    self._classes = ('__background__',  # always index 0
                     'aeroplane', 'bicycle', 'bird', 'boat',
                     'bottle', 'bus', 'car', 'cat', 'chair',
                     'cow', 'diningtable', 'dog', 'horse',
                     'motorbike', 'person', 'pottedplant',
                     'sheep', 'sofa', 'train', 'tvmonitor')
    self._class_to_ind = dict(list(zip(self.classes, list(range(self.num_classes)))))
    self._image_ext = '.jpg'
    self._image_index = self._load_image_set_index()
    # Default to roidb handler
    self._roidb_handler = self.gt_roidb
    self._salt = str(uuid.uuid4())
    self._comp_id = 'comp4'

    # PASCAL specific config options
    self.config = {'cleanup': True,
                   'use_salt': True,
                   'use_diff': use_diff,
                   'matlab_eval': False,
                   'rpn_file': None}

    assert os.path.exists(self._devkit_path), \
      'VOCdevkit path does not exist: {}'.format(self._devkit_path)
    assert os.path.exists(self._data_path), \
      'Path does not exist: {}'.format(self._data_path)

  def image_path_at(self, i):
    """
    Return the absolute path to image i in the image sequence.
    """
    return self.image_path_from_index(self._image_index[i])

  def image_path_from_index(self, index):
    """
    Construct an image path from the image's "index" identifier.
    """
    image_path = os.path.join(self._data_path, 'JPEGImages',
                              index + self._image_ext)
    assert os.path.exists(image_path), \
      'Path does not exist: {}'.format(image_path)
    return image_path

  def _load_image_set_index(self):
    """
    Load the indexes listed in this dataset's image set file.
    """
    # Example path to image set file:
    # self._devkit_path + /VOCdevkit2007/VOC2007/ImageSets/Main/val.txt
    image_set_file = os.path.join(self._data_path, 'ImageSets', 'Main',
                                  self._image_set + '.txt')
    assert os.path.exists(image_set_file), \
      'Path does not exist: {}'.format(image_set_file)
    with open(image_set_file) as f:
      image_index = [x.strip() for x in f.readlines()]
    return image_index

  def _get_default_path(self):
    """
    Return the default path where PASCAL VOC is expected to be installed.
    """
    return os.path.join(cfg.DATA_DIR, 'VOCdevkit' + self._year)

  def gt_roidb(self):
    """
    Return the database of ground-truth regions of interest.

    This function loads/saves from/to a cache file to speed up future calls.
    """
    cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
    if os.path.exists(cache_file):
      with open(cache_file, 'rb') as fid:
        try:
          roidb = pickle.load(fid)
        except:
          roidb = pickle.load(fid, encoding='bytes')
      print('{} gt roidb loaded from {}'.format(self.name, cache_file))
      return roidb

    gt_roidb = [self._load_pascal_annotation(index)
                for index in self.image_index]
    with open(cache_file, 'wb') as fid:
      pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)
    print('wrote gt roidb to {}'.format(cache_file))

    return gt_roidb

  def rpn_roidb(self):
    if int(self._year) == 2007 or self._image_set != 'test':
      gt_roidb = self.gt_roidb()
      rpn_roidb = self._load_rpn_roidb(gt_roidb)
      roidb = imdb.merge_roidbs(gt_roidb, rpn_roidb)
    else:
      roidb = self._load_rpn_roidb(None)

    return roidb

  def _load_rpn_roidb(self, gt_roidb):
    filename = self.config['rpn_file']
    print('loading {}'.format(filename))
    assert os.path.exists(filename), \
      'rpn data not found at: {}'.format(filename)
    with open(filename, 'rb') as f:
      box_list = pickle.load(f)
    return self.create_roidb_from_box_list(box_list, gt_roidb)

  def _load_pascal_annotation(self, index):
    """
    Load image and bounding boxes info from XML file in the PASCAL VOC
    format.
    """
    filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
    tree = ET.parse(filename)
    objs = tree.findall('object')
    if not self.config['use_diff']:
      # Exclude the samples labeled as difficult
      non_diff_objs = [
        obj for obj in objs if int(obj.find('difficult').text) == 0]
      # if len(non_diff_objs) != len(objs):
      #     print 'Removed {} difficult objects'.format(
      #         len(objs) - len(non_diff_objs))
      objs = non_diff_objs
    num_objs = len(objs)

    boxes = np.zeros((num_objs, 4), dtype=np.uint16)
    gt_classes = np.zeros((num_objs), dtype=np.int32)
    overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
    # "Seg" area for pascal is just the box area
    seg_areas = np.zeros((num_objs), dtype=np.float32)

    # Load object bounding boxes into a data frame.
    for ix, obj in enumerate(objs):
      bbox = obj.find('bndbox')
      # Make pixel indexes 0-based
      x1 = float(bbox.find('xmin').text) - 1
      y1 = float(bbox.find('ymin').text) - 1
      x2 = float(bbox.find('xmax').text) - 1
      y2 = float(bbox.find('ymax').text) - 1
      cls = self._class_to_ind[obj.find('name').text.lower().strip()]
      boxes[ix, :] = [x1, y1, x2, y2]
      gt_classes[ix] = cls
      overlaps[ix, cls] = 1.0
      seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)

    overlaps = scipy.sparse.csr_matrix(overlaps)

    return {'boxes': boxes,
            'gt_classes': gt_classes,
            'gt_overlaps': overlaps,
            'flipped': False,
            'seg_areas': seg_areas}

  def _get_comp_id(self):
    comp_id = (self._comp_id + '_' + self._salt if self.config['use_salt']
               else self._comp_id)
    return comp_id

  def _get_voc_results_file_template(self):
    # VOCdevkit/results/VOC2007/Main/<comp_id>_det_test_aeroplane.txt
    filename = self._get_comp_id() + '_det_' + self._image_set + '_{:s}.txt'
    path = os.path.join(
      self._devkit_path,
      'results',
      'VOC' + self._year,
      'Main',
      filename)
    return path

  def _write_voc_results_file(self, all_boxes):
    for cls_ind, cls in enumerate(self.classes):
      if cls == '__background__':
        continue
      print('Writing {} VOC results file'.format(cls))
      filename = self._get_voc_results_file_template().format(cls)
      with open(filename, 'wt') as f:
        for im_ind, index in enumerate(self.image_index):
          dets = all_boxes[cls_ind][im_ind]
          if dets == []:
            continue
          # the VOCdevkit expects 1-based indices
          for k in range(dets.shape[0]):
            f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.
                    format(index, dets[k, -1],
                           dets[k, 0] + 1, dets[k, 1] + 1,
                           dets[k, 2] + 1, dets[k, 3] + 1))

  def _do_python_eval(self, output_dir='output'):
    annopath = os.path.join(
      self._devkit_path,
      'VOC' + self._year,
      'Annotations',
      '{:s}.xml')
    imagesetfile = os.path.join(
      self._devkit_path,
      'VOC' + self._year,
      'ImageSets',
      'Main',
      self._image_set + '.txt')
    cachedir = os.path.join(self._devkit_path, 'annotations_cache')
    aps = []
    # The PASCAL VOC metric changed in 2010
    use_07_metric = True if int(self._year) < 2010 else False
    print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No'))
    if not os.path.isdir(output_dir):
      os.mkdir(output_dir)
    for i, cls in enumerate(self._classes):
      if cls == '__background__':
        continue
      filename = self._get_voc_results_file_template().format(cls)
      rec, prec, ap = voc_eval(
        filename, annopath, imagesetfile, cls, cachedir, ovthresh=0.5,
        use_07_metric=use_07_metric, use_diff=self.config['use_diff'])
      aps += [ap]
      print(('AP for {} = {:.4f}'.format(cls, ap)))
      with open(os.path.join(output_dir, cls + '_pr.pkl'), 'wb') as f:
        pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
    print(('Mean AP = {:.4f}'.format(np.mean(aps))))
    print('~~~~~~~~')
    print('Results:')
    for ap in aps:
      print(('{:.3f}'.format(ap)))
    print(('{:.3f}'.format(np.mean(aps))))
    print('~~~~~~~~')
    print('')
    print('--------------------------------------------------------------')
    print('Results computed with the **unofficial** Python eval code.')
    print('Results should be very close to the official MATLAB eval code.')
    print('Recompute with `./tools/reval.py --matlab ...` for your paper.')
    print('-- Thanks, The Management')
    print('--------------------------------------------------------------')

  def _do_matlab_eval(self, output_dir='output'):
    print('-----------------------------------------------------')
    print('Computing results with the official MATLAB eval code.')
    print('-----------------------------------------------------')
    path = os.path.join(cfg.ROOT_DIR, 'lib', 'datasets',
                        'VOCdevkit-matlab-wrapper')
    cmd = 'cd {} && '.format(path)
    cmd += '{:s} -nodisplay -nodesktop '.format(cfg.MATLAB)
    cmd += '-r "dbstop if error; '
    cmd += 'voc_eval(\'{:s}\',\'{:s}\',\'{:s}\',\'{:s}\'); quit;"' \
      .format(self._devkit_path, self._get_comp_id(),
              self._image_set, output_dir)
    print(('Running:\n{}'.format(cmd)))
    status = subprocess.call(cmd, shell=True)

  def evaluate_detections(self, all_boxes, output_dir):
    self._write_voc_results_file(all_boxes)
    self._do_python_eval(output_dir)
    if self.config['matlab_eval']:
      self._do_matlab_eval(output_dir)
    if self.config['cleanup']:
      for cls in self._classes:
        if cls == '__background__':
          continue
        filename = self._get_voc_results_file_template().format(cls)
        os.remove(filename)

  def competition_mode(self, on):
    if on:
      self.config['use_salt'] = False
      self.config['cleanup'] = False
    else:
      self.config['use_salt'] = True
      self.config['cleanup'] = True
  • init是初始化函數(shù)它抱,對應(yīng)著的是pascal_voc的數(shù)據(jù)集訪問格式
  def __init__(self, image_set, year, use_diff=False):
    name = 'voc_' + year + '_' + image_set
    if use_diff:
      name += '_diff'
    imdb.__init__(self, name)  #繼承了類imdb的初始化函數(shù)__init__(),傳進去的參數(shù)是voc_2007_train朴艰。類imdb在lib/datasets/imdb.py里面被定義
    self._year = year #是一個str观蓄,是VOC數(shù)據(jù)的年份混移,值是'2007'或者'2012',以2007為例
    self._image_set = image_set #是一個str侮穿,值是'train'或者'test'或者'trainval'或者'val'歌径,表示的意思是用(訓(xùn)練集)或者(測試集)或者(訓(xùn)練驗證集)或者(驗證集)里面的數(shù)據(jù),以train為例
    self._devkit_path = self._get_default_path() #調(diào)用def _get_default_path(self)        路徑data/VOCdevkit/VOC2007
    self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)#VOC2007
    self._classes = ('__background__',  # always index 0
                     'aeroplane', 'bicycle', 'bird', 'boat',  
                     'bottle', 'bus', 'car', 'cat', 'chair',
                     'cow', 'diningtable', 'dog', 'horse',
                     'motorbike', 'person', 'pottedplant',
                     'sheep', 'sofa', 'train', 'tvmonitor') 
    #數(shù)據(jù)集中所包含的全部的object類別
    self._class_to_ind = dict(list(zip(self.classes, list(range(self.num_classes)))))  
# 構(gòu)建字典{'__background__':'0','aeroplane':'1', 'bicycle':'2', 'bird':'3', 'boat':'4','bottle':'5', 'bus':'6', 'car':'7', 'cat':'8', 'chair':'9','cow':'10', 'diningtable':'11', 'dog':'12', 'horse':'13','motorbike':'14', 'person':'15', 'pottedplant':'16','sheep':'17', 'sofa':'18', 'train':'19', 'tvmonitor':'20'}  self.num_classes是object的類別總數(shù)21(背景background也算一類)亲茅,這個函數(shù)繼承自lib/datasets/imdb.py
    self._image_ext = '.jpg'  # 圖片后綴名
    self._image_index = self._load_image_set_index()  #加載了樣本的list文件
    # Default to roidb handler
    self._roidb_handler = self.gt_roidb   # 當(dāng)有RPN的時候,讀取并返回圖片gt的db回铛。函數(shù)gt_roidb里面并沒有提取圖片的ROI,因為faster-rcnn有RPN克锣,用RPN來提取ROI茵肃。函數(shù)gt_roidb返回的是圖片的gt。(fast-rcnn沒有RPN)
    self._salt = str(uuid.uuid4())
    self._comp_id = 'comp4'

    # PASCAL specific config options
    self.config = {'cleanup': True,
                   'use_salt': True,
                   'use_diff': use_diff,
                   'matlab_eval': False,
                   'rpn_file': None}

    assert os.path.exists(self._devkit_path), \
      'VOCdevkit path does not exist: {}'.format(self._devkit_path) #如果路徑self._devkit_path(也就是目錄VOCdevkit)不存在袭祟,退出
    assert os.path.exists(self._data_path), \
      'Path does not exist: {}'.format(self._data_path)#如果路徑self._data_path(也就是VOCdevkit/VOC2007)不存在免姿,退出
  • 子函數(shù)def _get_default_path(self)
   def _get_default_path(self):
    """
    Return the default path where PASCAL VOC is expected to be installed.
返回數(shù)據(jù)集pascal_voc的默認(rèn)路徑:tf-faster-rcnn/data/VOCdevkit/2007
    """
    return os.path.join(cfg.DATA_DIR, 'VOCdevkit' + self._year)#cfg.DATA_DIR是在tf-faster-rcnn/lib/model/config.py里面定義

tf-faster-rcnn/lib/model/config.py中定義DATA_DIR的地方是這樣的(在257-261行):

# Root directory of project
__C.ROOT_DIR = osp.abspath(osp.join(osp.dirname(__file__), '..', '..'))

# Data directory
__C.DATA_DIR = osp.abspath(osp.join(__C.ROOT_DIR, 'data'))
  • 子函數(shù)def _load_image_set_index(self)
  def _load_image_set_index(self):
    """
    Load the indexes listed in this dataset's image set file.
    """
    # Example path to image set file:
    # self._devkit_path + /VOCdevkit2007/VOC2007/ImageSets/Main/val.txt
    image_set_file = os.path.join(self._data_path, 'ImageSets', 'Main',
                                  self._image_set + '.txt')
 # image_set_file就是tf-faster-rcnn/data/VOCdevkit2007/VOC2007/ImageSets/Layout/train.txt
 #之所以要讀這個train.txt文件,是因為train.txt文件里面寫的是集合train中所有圖片的名字(沒有后綴.jpg)
    assert os.path.exists(image_set_file), \
      'Path does not exist: {}'.format(image_set_file)
    with open(image_set_file) as f:  # 讀上面的train.txt文件
      image_index = [x.strip() for x in f.readlines()]  #將train.txt的內(nèi)容(圖片名字)讀取出來放在image_index里面
    return image_index  #得到image_set里面所有圖片的名字(沒有后綴.jpg)

得到一個list,這個list里面是集合self._image_set中所有圖片的名字(注意榕酒,圖片名字沒有后綴.jpg)

  • 子函數(shù)def gt_roidb(self)
  def gt_roidb(self):
    """
    Return the database of ground-truth regions of interest.

    This function loads/saves from/to a cache file to speed up future calls.
    """
    cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')   
#給.pkl文件起個名字胚膊。參數(shù)self.cache_path和self.name繼承自類imdb,類imdb在lib/datasets/imdb.py里面定義
    if os.path.exists(cache_file):  # 如果這個.pkl文件存在(說明之前執(zhí)行過本函數(shù),生成了這個pkl文件)即預(yù)處理模型pretrain model
      with open(cache_file, 'rb') as fid:  #打開
        try:
          roidb = pickle.load(fid)
        except:
          roidb = pickle.load(fid, encoding='bytes')  #將里面的數(shù)據(jù)加載進來
      print('{} gt roidb loaded from {}'.format(self.name, cache_file))
      return roidb  #返回

    gt_roidb = [self._load_pascal_annotation(index)    # 如果這個.pkl文件不存在想鹰,說明是第一次執(zhí)行本函數(shù)紊婉。
                for index in self.image_index]  #那么首先要做的就是獲取圖片的gt,函數(shù)_load_pascal_annotation的作用就是獲取圖片gt辑舷。
    with open(cache_file, 'wb') as fid:    pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)  #將圖片的gt保存在.pkl文件里面
    print('wrote gt roidb to {}'.format(cache_file))

    return gt_roidb

讀取并返回圖片gt的db喻犁。這個函數(shù)就是將圖片的gt加載進來。
其中何缓,pascal_voc圖片的gt信息在XML文件中肢础;并且,圖片的gt被提前放在了一個.pkl文件里面碌廓。(這個.pkl文件需要我們自己生成传轰,代碼就在該函數(shù)中)之所以會將圖片的gt提前放在一個.pkl文件里面,是為了不用每次都再重新讀圖片的gt谷婆,直接加載這個文件就可以了慨蛙,可以提升速度。
參數(shù)self.cache_path和self.name繼承自類imdb,類imdb在tf-faster-rcnn/lib/datasets/imdb.py里面被定義纪挎。類imdb中定義函數(shù)self.cache_path的地方在imdb.py中的77-82行:

@property
  def name(self):
    return self._name

@property
  def cache_path(self):
    cache_path = osp.abspath(osp.join(cfg.DATA_DIR, 'cache'))
    if not os.path.exists(cache_path):
      os.makedirs(cache_path)
    return cache_path

類imdb中定義函數(shù)self.name的地方在imdb.py中的23-35行:

  def __init__(self, name, classes=None):  #是類imdb的初始化函數(shù)期贫,在pascal_voc.py被用到
    self._name = name   # name是形參,傳進來的參數(shù)是'voc_2007_train' or ‘voc_2007_test’ or 'voc_2007_val' or 'voc_2007_trainval'
    self._num_classes = 0
    if not classes:
      self._classes = []  #類imdb中定義函數(shù)self.name的地方
    else:
      self._classes = classes
    self._image_index = []
    self._obj_proposer = 'gt'
    self._roidb = None
    self._roidb_handler = self.default_roidb
    # Use this dict for storing dataset specific config options
    self.config = {}

@property
  def name(self):  #類imdb中定義函數(shù)self.name的地方
    return self._name  #返回的是本文件imdb.py中的self._name

注意:如果再次訓(xùn)練的時候修改了train數(shù)據(jù)庫异袄,增加或者刪除了一些數(shù)據(jù)通砍,再想重新訓(xùn)練的時候,一定要先刪除這個output中的.pkl文件烤蜕。因為如果不刪除的話封孙,就會自動加載舊的pkl文件垢揩,而不會生成新的pkl文件。

  • 子函數(shù)def _load_pascal_annotation(self, index)敛瓷,這個函數(shù)是讀取圖片gt的具體實現(xiàn)
  def _load_pascal_annotation(self, index):
    """
    Load image and bounding boxes info from XML file in the PASCAL VOC
    format.
   從XML文件中獲取圖片信息和gt叁巨。
   這個XML文件存儲的是PASCAL VOC圖片的信息和gt的信息,下載VOC數(shù)據(jù)集的時候呐籽,XML文件是一塊下載下來的锋勺。在文件夾Annotation里面。
    """
    filename = os.path.join(self._data_path, 'Annotations', index + '.xml')  
#這個filename就是一個XML文件的路徑,其中index是一張圖片的名字(沒有后綴)狡蝶。例如VOCdevkit2007/VOC2007/Annotations/000005.xml
    tree = ET.parse(filename)
    objs = tree.findall('object')
    if not self.config['use_diff']:
      # Exclude the samples labeled as difficult
      non_diff_objs = [
        obj for obj in objs if int(obj.find('difficult').text) == 0]
      # if len(non_diff_objs) != len(objs):
      #     print 'Removed {} difficult objects'.format(
      #         len(objs) - len(non_diff_objs))
      objs = non_diff_objs
    num_objs = len(objs)  # 輸進來的圖片上的物體object的個數(shù)

    boxes = np.zeros((num_objs, 4), dtype=np.uint16)
    gt_classes = np.zeros((num_objs), dtype=np.int32)
    overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
    # "Seg" area for pascal is just the box area
    seg_areas = np.zeros((num_objs), dtype=np.float32)

    # Load object bounding boxes into a data frame.
    for ix, obj in enumerate(objs):  # 對于該圖片上每一個object
      bbox = obj.find('bndbox')   
 # pascal_voc的XML文件中給出的圖片gt的形式是左上角和右下角的坐標(biāo)
      # Make pixel indexes 0-based
      x1 = float(bbox.find('xmin').text) - 1
      y1 = float(bbox.find('ymin').text) - 1
      x2 = float(bbox.find('xmax').text) - 1
      y2 = float(bbox.find('ymax').text) - 1
#為什么要減去1庶橱?是因為VOC的數(shù)據(jù),坐標(biāo)-1贪惹,默認(rèn)坐標(biāo)從0開始(這個還有待商榷苏章,先忽略)
      cls = self._class_to_ind[obj.find('name').text.lower().strip()]
#找到該object的類別cls
      boxes[ix, :] = [x1, y1, x2, y2]
      gt_classes[ix] = cls
      overlaps[ix, cls] = 1.0
      seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)
# seg_areas[ix]是該object gt的面積

    overlaps = scipy.sparse.csr_matrix(overlaps)

    return {'boxes': boxes,
            'gt_classes': gt_classes,
            'gt_overlaps': overlaps,
            'flipped': False,
            'seg_areas': seg_areas}
  • 子函數(shù)def image_path_at(self, i)
  def image_path_at(self, i):
    """
    Return the absolute path to image i in the image sequence.
    """
    return self.image_path_from_index(self._image_index[i])

根據(jù)第i個圖像樣本返回其對應(yīng)的path,其調(diào)用了image_path_from_index(self, index)作為其具體實現(xiàn)奏瞬。

  • 子函數(shù)def image_path_from_index(self, index)
  def image_path_from_index(self, index):
    """
    Construct an image path from the image's "index" identifier.
    """
    image_path = os.path.join(self._data_path, 'JPEGImages',
                              index + self._image_ext) 
 #這個就是圖片本身所在的路徑枫绅。其中index是一張圖片的名字(沒有后綴),_image_ext是圖片后綴名.jpg硼端。例如VOCdevkit2007/VOC2007/JPEGImages/000005.jpg
    assert os.path.exists(image_path), \
      'Path does not exist: {}'.format(image_path)
# 如果該路徑不存在并淋,退出
    return image_path

以上可見,pascal_voc.py用了較多的路徑拼接

3 修改模型文件配置

  • 修改config.py
    工程tf-faster-rcnn中模型的參數(shù)都在文件tf-faster-rcnn/lib/model/config.py中被定義珍昨。
# Images to use per minibatch
__C.TRAIN.IMS_PER_BATCH = 1  #每次輸入到faster-rcnn網(wǎng)絡(luò)中的圖片數(shù)量是1張

# Iterations between snapshots
__C.TRAIN.SNAPSHOT_ITERS = 5000 # 訓(xùn)練的時候县耽,每5000步保存一次模型。

# solver.prototxt specifies the snapshot path prefix, this adds an optional
# infix to yield the path: <prefix>[_<infix>]_iters_XYZ.caffemodel
__C.TRAIN.SNAPSHOT_PREFIX = 'res101_faster_rcnn'  #模型在保存時的名字

# Use RPN to detect objects
__C.TRAIN.HAS_RPN = True #是否使用RPN镣典。True代表使用RPN
  • demo.py分析
CLASSES = ('__background__',
           '"seaurchin"', '"scallop"', '"seacucumber"')
NETS = {'vgg16': ('vgg16_faster_rcnn_iter_15000.ckpt',),
'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}

def vis_detections模塊:畫出測試圖片的bounding boxes兔毙, 參數(shù)im為測試圖片; class_name 為類別名稱兄春,在前面定義的 CLASSES 中澎剥; dets為非極大值抑制后的bbox和score的數(shù)組;thresh是最后score的閾值神郊,高于該閾值的候選框才會被畫出來肴裙。

def vis_detections(im, class_name, dets, thresh=0.5):
    """Draw detected bounding boxes."""  
    ##選取候選框score大于閾值的dets
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return
# python-opencv 中讀取圖片默認(rèn)保存為[w,h,channel](w,h順序不確定)
    # 其中 channel:BGR 存儲,而畫圖時涌乳,需要按RGB格式,因此此處作轉(zhuǎn)換甜癞。

    im = im[:, :, (2, 1, 0)]
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')
    for i in inds:   #從dets中取出 bbox, score
        bbox = dets[i, :4]
        score = dets[i, -1]
#  根據(jù)起始點坐標(biāo)以及w,h 畫出矩形框
        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          edgecolor='red', linewidth=3.5)
            )
        ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')

    ax.set_title(('{} detections with '
                  'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                  thresh),
                  fontsize=14)
    plt.axis('off')
    plt.tight_layout()
    plt.draw()

def demo模塊:對測試圖片提取預(yù)選框夕晓,并進行非極大值抑制,然后調(diào)用def vis_detections 畫矩形框悠咱。參數(shù):net 測試時使用的網(wǎng)絡(luò)結(jié)構(gòu)蒸辆;image_name:圖片名稱征炼。

def demo(sess, net, image_name):
    """Detect object classes in an image using pre-computed object proposals."""

    # Load the demo image
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    im = cv2.imread(im_file)

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(sess, net, im)
    timer.toc()
    print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time, boxes.shape[0]))

    # Visualize detections for each class
    CONF_THRESH = 0.8   
#score 閾值,最后畫出候選框時需要躬贡,>thresh才會被畫出
    NMS_THRESH = 0.3
 #非極大值抑制的閾值谆奥,剔除重復(fù)候選框
    for cls_ind, cls in enumerate(CLASSES[1:]):
  #利用enumerate函數(shù),獲得CLASSES中 類別的下標(biāo)cls_ind和類別名cls
        cls_ind += 1 # because we skipped background
#將bbox,score 一起存入dets
        cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]  
# because we skipped background
        cls_scores = scores[:, cls_ind]
   #取出bbox ,score#將bbox,score 一起存入dets
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
 #進行非極大值抑制拂玻,得到抑制后的 dets
        dets = dets[keep, :]    #畫框
        vis_detections(im, cls, dets, thresh=CONF_THRESH)
  • def parse_args模塊:解析命令行參數(shù)酸些,得到gpu||cpu, net等。
def parse_args():
    """Parse input arguments."""
 #創(chuàng)建解析對象
    parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
    #parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
    #                   choices=NETS.keys(), default='res101')  #default
    parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
                        choices=NETS.keys(), default='vgg16')
    parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]',
                        choices=DATASETS.keys(), default='pascal_voc_0712')
#調(diào)用parser.parse_args進行解析檐蚜,返回帶標(biāo)注的args
    args = parser.parse_args()

    return args
  • 主函數(shù)
if __name__ == '__main__':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals
#解析
    args = parse_args()
#添加路徑
    cfg.USE_GPU_NMS = False
    # model path
    demonet = args.demo_net
    dataset = args.dataset
    tfmodel = os.path.join('output', demonet, DATASETS[dataset][0], 'default',
                              NETS[demonet][0])


    if not os.path.isfile(tfmodel + '.meta'):
        raise IOError(('{:s} not found.\nDid you download the proper networks from '
                       'our server and place them properly?').format(tfmodel + '.meta'))

    # set config
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth=True

    # init session
    sess = tf.Session(config=tfconfig)
    # load network
    if demonet == 'vgg16':
        net = vgg16()
    elif demonet == 'res101':
        net = resnetv1(num_layers=101)
    else:
        raise NotImplementedError
    net.create_architecture("TEST", 21,
                          tag='default', anchor_scales=[8, 16, 32])
#用自己的數(shù)據(jù)集測試時魄懂,21根據(jù)classes類別數(shù)量修改
    saver = tf.train.Saver()
    saver.restore(sess, tfmodel)

    print('Loaded network {:s}'.format(tfmodel))

    #im_names = ['000456.jpg', '000542.jpg', '001150.jpg',
    #           '001763.jpg', '004545.jpg']  #default
    im_names = ['000000.jpg']
    for im_name in im_names:
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print('Demo for data/demo/{}'.format(im_name))
        demo(sess, net, im_name)

plt.show()
  • 根據(jù)自己的數(shù)據(jù)集訓(xùn)練好模型后,要想運行Demo并將所有類別在同一圖片顯示闯第,需要按照如下代碼進行修改調(diào)整市栗。
#!/usr/bin/env python

# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Xinlei Chen, based on code from Ross Girshick
# --------------------------------------------------------

"""
Demo script showing detections in sample images.
See README.md for installation instructions before running.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import _init_paths
from model.config import cfg
from model.test import im_detect
from model.nms_wrapper import nms

from utils.timer import Timer
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os, cv2
import argparse

from nets.vgg16 import vgg16
from nets.resnet_v1 import resnetv1

CLASSES = ('__background__',
           '"seaurchin"', '"scallop"', '"seacucumber"')

NETS = {'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',),'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}

DATASETS= {'pascal_voc': ('voc_2007_trainval',),'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}

#增加ax參數(shù),即第4項
def vis_detections(im, class_name, dets, ax, thresh=0.5):
    """Draw detected bounding boxes."""
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return

#注釋原代碼的以下三行
    #im = im[:, :, (2, 1, 0)]
    #fig, ax = plt.subplots(figsize=(12, 12))
    #ax.imshow(im, aspect='equal')
    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]

        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          #edgecolor='red', linewidth=3.5)
                          edgecolor='red', linewidth=1)  
                          # 矩形線寬從3.5改為1咳短,紅框變細
            )
        ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')

    ax.set_title(('{} detections with '
                  'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                  thresh),
                  fontsize=14)
#注釋原代碼以下三行
    #plt.axis('off')
    #plt.tight_layout()
    #plt.draw()

def demo(sess, net, image_name):
    """Detect object classes in an image using pre-computed object proposals."""

    # Load the demo image
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    im = cv2.imread(im_file)

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(sess, net, im)
    timer.toc()
    print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time, boxes.shape[0]))

    # Visualize detections for each class
    CONF_THRESH = 0.8
    NMS_THRESH = 0.3
 # 將vis_detections 函數(shù)中for 循環(huán)之前的3行代碼移動到這里
    im = im[:, :, (2, 1, 0)]
    fig,ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')

    for cls_ind, cls in enumerate(CLASSES[1:]):
        cls_ind += 1 # because we skipped background
        cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
        #將ax做為參數(shù)傳入vis_detections填帽,即增加第4項
        vis_detections(im, cls, dets,ax,thresh=CONF_THRESH)
    # 將vis_detections 函數(shù)中for 循環(huán)之后的3行代碼移動到這里
    plt.axis('off')
    plt.tight_layout()
    plt.draw()



def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
    #parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
    #                   choices=NETS.keys(), default='res101')  #default
    parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
                        choices=NETS.keys(), default='vgg16')
    parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]',
                        choices=DATASETS.keys(), default='pascal_voc_0712')
    args = parser.parse_args()

    return args

if __name__ == '__main__':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals
    args = parse_args()
    cfg.USE_GPU_NMS = False
    # model path
    demonet = args.demo_net
    dataset = args.dataset
    tfmodel = os.path.join('output', demonet, DATASETS[dataset][0], 'default',
                              NETS[demonet][0])


    if not os.path.isfile(tfmodel + '.meta'):
        raise IOError(('{:s} not found.\nDid you download the proper networks from '
                       'our server and place them properly?').format(tfmodel + '.meta'))

    # set config
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth=True

    # init session
    sess = tf.Session(config=tfconfig)
    # load network
    if demonet == 'vgg16':
        net = vgg16()
    elif demonet == 'res101':
        net = resnetv1(num_layers=101)
    else:
        raise NotImplementedError
    net.create_architecture("TEST", 4,
                          tag='default', anchor_scales=[8, 16, 32])
#net.create_architecture第2個參數(shù)是需要識別的類別+1,本例有3個待識別物體咙好,加background共計4
    saver = tf.train.Saver()
    saver.restore(sess, tfmodel)

    print('Loaded network {:s}'.format(tfmodel))

    #im_names = ['000456.jpg', '000542.jpg', '001150.jpg',
    #           '001763.jpg', '004545.jpg']  #default
    im_names = ['000337.jpg']   #測試的圖片盲赊,保存在tf-faster-rcnn-contest/data/demo 路徑
    for im_name in im_names:
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print('Demo for data/demo/{}'.format(im_name))
        demo(sess, net, im_name)

plt.show()

  • 根據(jù)自己的數(shù)據(jù)集訓(xùn)練好模型后,要想運行demo.py批量處理測試圖片敷扫,并將所有類別在同一圖片顯示哀蘑,需要按照如下代碼進行修改調(diào)整。
#!/usr/bin/env python

# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Xinlei Chen, based on code from Ross Girshick
# --------------------------------------------------------

"""
Demo script showing detections in sample images.
See README.md for installation instructions before running.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import _init_paths
from model.config import cfg
from model.test import im_detect
from model.nms_wrapper import nms

from utils.timer import Timer
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os, cv2
import argparse

from nets.vgg16 import vgg16
from nets.resnet_v1 import resnetv1

import scipy.io as sio
import os, sys, cv2
import argparse

import os
import numpy
from PIL import Image   #導(dǎo)入Image模塊
from pylab import *     #導(dǎo)入savetxt模塊

CLASSES = ('__background__',
           'holothurian', 'echinus', 'scallop', 'starfish')

NETS = {'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',),'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}

DATASETS= {'pascal_voc': ('voc_2007_trainval',),'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}

def vis_detections(im, class_name, dets, ax, thresh=0.5):
    """Draw detected bounding boxes."""
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return

    #im = im[:, :, (2, 1, 0)]
    #fig, ax = plt.subplots(figsize=(12, 12))
    #ax.imshow(im, aspect='equal')
    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]

        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          #edgecolor='red', linewidth=3.5)
                          edgecolor='red', linewidth=1)
            )
        ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')

    ax.set_title(('{} detections with '
                  'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                  thresh),
                  fontsize=14)
    #plt.axis('off')
    #plt.tight_layout()
    #plt.draw()

def demo(sess, net, image_name):
    """Detect object classes in an image using pre-computed object proposals."""

    # Load the demo image
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    im = cv2.imread(im_file)

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(sess, net, im)
    timer.toc()
    print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time, boxes.shape[0]))

    save_jpg = os.path.join('/data/test',im_name)

    # Visualize detections for each class
    CONF_THRESH = 0.8
    NMS_THRESH = 0.3
    im = im[:, :, (2, 1, 0)]
    fig,ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')

    for cls_ind, cls in enumerate(CLASSES[1:]):
        cls_ind += 1 # because we skipped background
        cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]

        vis_detections(im, cls, dets,ax,thresh=CONF_THRESH)
    plt.axis('off')
    plt.tight_layout()
    plt.draw()

def get_imlist(path):  # 此函數(shù)讀取特定文件夾下的jpg格式圖像
    return [os.path.join(f) for f in os.listdir(path) if f.endswith('.jpg')]



def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
    #parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
    #                   choices=NETS.keys(), default='res101')  #default
    parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
                        choices=NETS.keys(), default='vgg16')
    parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]',
                        choices=DATASETS.keys(), default='pascal_voc_0712')
    args = parser.parse_args()

    return args

if __name__ == '__main__':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals
    args = parse_args()
    cfg.USE_GPU_NMS = False
    # model path
    demonet = args.demo_net
    dataset = args.dataset
    tfmodel = os.path.join('output', demonet, DATASETS[dataset][0], 'default',
                              NETS[demonet][0])


    if not os.path.isfile(tfmodel + '.meta'):
        raise IOError(('{:s} not found.\nDid you download the proper networks from '
                       'our server and place them properly?').format(tfmodel + '.meta'))

    # set config
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth=True

    # init session
    sess = tf.Session(config=tfconfig)
    # load network
    if demonet == 'vgg16':
        net = vgg16()
    elif demonet == 'res101':
        net = resnetv1(num_layers=101)
    else:
        raise NotImplementedError
    net.create_architecture("TEST",5,
                          tag='default', anchor_scales=[8, 16, 32])
    saver = tf.train.Saver()
    saver.restore(sess, tfmodel)

    print('Loaded network {:s}'.format(tfmodel))


    #im_names = ['000456.jpg', '000542.jpg', '001150.jpg',
    #           '001763.jpg', '004545.jpg']  #default
    #im_names = ['000456.jpg', '000542.jpg', '001150.jpg',
    #           '001763.jpg', '004545.jpg']
    im_names = get_imlist(r"/home/ouc/LiuHongzhi/tf-faster-rcnn-contest -2018/data/demo")
    print(im_names)
    for im_name in im_names:
    #path = "/home/henry/Files/URPC2018/VOC/VOC2007/JPEGImages/G0024172/*.jpg"
    #filelist = os.listdir(path)
    #for im_name in path:
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print('Demo for data/demo/{}'.format(im_name))
        demo(sess, net, im_name)
        plt.savefig("testfigs/" + im_name)
#plt.show()
  • 根據(jù)自己的數(shù)據(jù)集訓(xùn)練好模型后葵第,要想運行demo.py批量處理測試圖片绘迁,并按照<image_id> <class_id> <confidence> <xmin> <ymin> <xmax> <ymax>格式輸出信息,需要按照如下代碼進行修改調(diào)整卒密。
#!/usr/bin/env python

# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Xinlei Chen, based on code from Ross Girshick
# --------------------------------------------------------

"""
Demo script showing detections in sample images.
See README.md for installation instructions before running.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import _init_paths
from model.config import cfg
from model.test import im_detect
from model.nms_wrapper import nms

from utils.timer import Timer
import tensorflow as tf

import matplotlib.pyplot as plt
import numpy as np
import os, cv2
import os.path
import argparse

from nets.vgg16 import vgg16
from nets.resnet_v1 import resnetv1

import scipy.io as sio
import os, sys, cv2
import argparse

import os
import numpy
from PIL import Image   #導(dǎo)入Image模塊
from pylab import *     #導(dǎo)入savetxt模塊

CLASSES = ('__background__',
           'holothurian', 'echinus', 'scallop', 'starfish')

NETS = {'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',),'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}

DATASETS= {'pascal_voc': ('voc_2007_trainval',),'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}

def vis_detections(im, class_name, dets, thresh=0.5):
    """Draw detected bounding boxes."""
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return

    #im = im[:, :, (2, 1, 0)]
    #fig, ax = plt.subplots(figsize=(12, 12))
    #ax.imshow(im, aspect='equal')

    # !/usr/bin/env python
    # -*- coding: UTF-8 -*-
    # --------------------------------------------------------
    # Faster R-CNN
    # Copyright (c) 2015 Microsoft
    # Licensed under The MIT License [see LICENSE for details]
    # Written by Ross Girshick
    # --------------------------------------------------------

    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]
        if class_name == '__background__':
            fw = open('result.txt', 'a')  # 最終的txt保存在這個路徑下缀台,下面的都改
            fw.write(str(im_name) + ' ' + class_name + ' ' + str(score) + ' ' +str(int(bbox[0])) + ' ' + str(int(bbox[1])) + ' ' + str(int(bbox[2])) + ' ' + str(int(bbox[3])) + '\n')
            fw.close()

        elif class_name == 'holothurian':
               fw = open('result.txt', 'a')  # 最終的txt保存在這個路徑下,下面的都改
               fw.write(str(im_name) + ' ' + class_name + ' ' + str(score) + ' ' +str(int(bbox[0])) + ' ' + str(int(bbox[1])) + ' ' + str(int(bbox[2])) + ' ' + str(int(bbox[3])) + '\n')
               fw.close()


        elif class_name == 'echinus':
             fw = open('result.txt', 'a')  # 最終的txt保存在這個路徑下哮奇,下面的都改
             fw.write(str(im_name) + ' ' + class_name + ' ' + str(score) + ' ' +str(int(bbox[0])) + ' ' + str(int(bbox[1])) + ' ' + str(int(bbox[2])) + ' ' + str(int(bbox[3])) + '\n')
             fw.close()

        elif class_name == 'scallop':
              fw = open('result.txt', 'a')  # 最終的txt保存在這個路徑下膛腐,下面的都改
              fw.write(str(im_name) + ' ' + class_name + ' ' + str(score) + ' ' +str(int(bbox[0])) + ' ' + str(int(bbox[1])) + ' ' + str(int(bbox[2])) + ' ' + str(int(bbox[3])) + '\n')
              fw.close()

        elif class_name == 'starfish':
              fw = open('result.txt', 'a')  # 最終的txt保存在這個路徑下,下面的都改
              fw.write(str(im_name) + ' ' + class_name + ' ' + str(score) + ' ' +str(int(bbox[0])) + ' ' + str(int(bbox[1])) + ' ' + str(int(bbox[2])) + ' ' + str(int(bbox[3])) + '\n')
              fw.close()

def demo(sess, net, image_name):
    """Detect object classes in an image using pre-computed object proposals."""

    # Load the demo image
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    im = cv2.imread(im_file)

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(sess, net, im)
    timer.toc()
    print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time, boxes.shape[0]))

    save_jpg = os.path.join('/data/test',im_name)

    # Visualize detections for each class
    CONF_THRESH = 0.8
    NMS_THRESH = 0.3
    #im = im[:, :, (2, 1, 0)]
    #fig,ax = plt.subplots(figsize=(12, 12))
    #ax.imshow(im, aspect='equal')

    for cls_ind, cls in enumerate(CLASSES[1:]):
        cls_ind += 1 # because we skipped background
        cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]

        vis_detections(im, cls, dets,thresh=CONF_THRESH)

def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
    #parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
    #                   choices=NETS.keys(), default='res101')  #default
    parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
                        choices=NETS.keys(), default='vgg16')
    parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]',
                        choices=DATASETS.keys(), default='pascal_voc_0712')
    args = parser.parse_args()

    return args

if __name__ == '__main__':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals
    args = parse_args()
    cfg.USE_GPU_NMS = False
    # model path
    demonet = args.demo_net
    dataset = args.dataset
    tfmodel = os.path.join('output', demonet, DATASETS[dataset][0], 'default',
                              NETS[demonet][0])


    if not os.path.isfile(tfmodel + '.meta'):
        raise IOError(('{:s} not found.\nDid you download the proper networks from '
                       'our server and place them properly?').format(tfmodel + '.meta'))

    # set config
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth=True

    # init session
    sess = tf.Session(config=tfconfig)
    # load network
    if demonet == 'vgg16':
        net = vgg16()
    elif demonet == 'res101':
        net = resnetv1(num_layers=101)
    else:
        raise NotImplementedError
    net.create_architecture("TEST",5,
                          tag='default', anchor_scales=[8, 16, 32])
    saver = tf.train.Saver()
    saver.restore(sess, tfmodel)

    print('Loaded network {:s}'.format(tfmodel))


    #im_names = ['000456.jpg', '000542.jpg', '001150.jpg',
    #           '001763.jpg', '004545.jpg']  #default
    #im_names = ['000456.jpg', '000542.jpg', '001150.jpg',
    #           '001763.jpg', '004545.jpg']

    im = 128 * np.ones((300, 500, 3), dtype=np.uint8)
    for i in range(2):
        _, _= im_detect(sess,net, im)

    #im_names = get_imlist(r"/home/henry/Files/tf-faster-rcnn-contest/data/demo")
    fr = open('/home/ouc/LiuHongzhi/tf-faster-rcnn-contest -2018/data/VOCdevkit2007/VOC2007/ImageSets/Main/test.txt', 'r')
    for im_name in fr:
    #path = "/home/henry/Files/URPC2018/VOC/VOC2007/JPEGImages/G0024172/*.jpg"
    #filelist = os.listdir(path)
    #for im_name in path:
        im_name = im_name.strip('\n')
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print('Demo for data/demo/{}'.format(im_name))
        demo(sess, net, im_name)
#plt.show()
fr.close

輸出效果如下:

000646.jpg echinus 0.9531797 617 89 785 272
000646.jpg echinus 0.94367296 200 272 396 495
000646.jpg echinus 0.9090044 953 259 1112 443
000646.jpg scallop 0.8987418 1508 975 1580 1037
000646.jpg scallop 0.8006968 512 169 580 218
000646.jpg starfish 0.96790546 291 675 390 765
001834.jpg echinus 0.9706842 291 222 365 280
001834.jpg echinus 0.965007 511 161 588 229
001834.jpg echinus 0.95911396 2 184 136 283

4 知識點補充

  • argparse
    argparse是python用于解析命令行參數(shù)和選項的標(biāo)準(zhǔn)模塊鼎俘,用于代替已經(jīng)過時的optparse模塊哲身。argparse模塊的作用是用于解析命令行參數(shù),例如python parseTest.py input.txt output.txt –user=name –port=8080贸伐。
    使用步驟:
    1:import argparse
    2:parser = argparse.ArgumentParser()
    3:parser.add_argument()
    4:parser.parse_args()
    解釋:首先導(dǎo)入該模塊勘天;然后創(chuàng)建一個解析對象;然后向該對象中添加你要關(guān)注的命令行參數(shù)和選項,每一個add_argument方法對應(yīng)一個你要關(guān)注的參數(shù)或選項脯丝;最后調(diào)用parse_args()方法進行解析商膊;

  • IoU非極大值抑制
    IoU參考

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市宠进,隨后出現(xiàn)的幾起案子晕拆,更是在濱河造成了極大的恐慌,老刑警劉巖材蹬,帶你破解...
    沈念sama閱讀 211,290評論 6 491
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件实幕,死亡現(xiàn)場離奇詭異,居然都是意外死亡赚导,警方通過查閱死者的電腦和手機茬缩,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 90,107評論 2 385
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來吼旧,“玉大人凰锡,你說我怎么就攤上這事∪Π担” “怎么了掂为?”我有些...
    開封第一講書人閱讀 156,872評論 0 347
  • 文/不壞的土叔 我叫張陵,是天一觀的道長员串。 經(jīng)常有香客問我勇哗,道長,這世上最難降的妖魔是什么寸齐? 我笑而不...
    開封第一講書人閱讀 56,415評論 1 283
  • 正文 為了忘掉前任欲诺,我火速辦了婚禮,結(jié)果婚禮上渺鹦,老公的妹妹穿的比我還像新娘扰法。我一直安慰自己,他們只是感情好毅厚,可當(dāng)我...
    茶點故事閱讀 65,453評論 6 385
  • 文/花漫 我一把揭開白布塞颁。 她就那樣靜靜地躺著,像睡著了一般吸耿。 火紅的嫁衣襯著肌膚如雪祠锣。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 49,784評論 1 290
  • 那天咽安,我揣著相機與錄音伴网,去河邊找鬼。 笑死板乙,一個胖子當(dāng)著我的面吹牛是偷,可吹牛的內(nèi)容都是我干的拳氢。 我是一名探鬼主播募逞,決...
    沈念sama閱讀 38,927評論 3 406
  • 文/蒼蘭香墨 我猛地睜開眼蛋铆,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了放接?” 一聲冷哼從身側(cè)響起刺啦,我...
    開封第一講書人閱讀 37,691評論 0 266
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎纠脾,沒想到半個月后玛瘸,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 44,137評論 1 303
  • 正文 獨居荒郊野嶺守林人離奇死亡苟蹈,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 36,472評論 2 326
  • 正文 我和宋清朗相戀三年糊渊,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片慧脱。...
    茶點故事閱讀 38,622評論 1 340
  • 序言:一個原本活蹦亂跳的男人離奇死亡渺绒,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出菱鸥,到底是詐尸還是另有隱情宗兼,我是刑警寧澤,帶...
    沈念sama閱讀 34,289評論 4 329
  • 正文 年R本政府宣布氮采,位于F島的核電站殷绍,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏鹊漠。R本人自食惡果不足惜主到,卻給世界環(huán)境...
    茶點故事閱讀 39,887評論 3 312
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望躯概。 院中可真熱鬧登钥,春花似錦、人聲如沸楞陷。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,741評論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽固蛾。三九已至结执,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間艾凯,已是汗流浹背献幔。 一陣腳步聲響...
    開封第一講書人閱讀 31,977評論 1 265
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留趾诗,地道東北人蜡感。 一個月前我還...
    沈念sama閱讀 46,316評論 2 360
  • 正文 我出身青樓蹬蚁,卻偏偏與公主長得像,于是被迫代替她去往敵國和親郑兴。 傳聞我的和親對象是個殘疾皇子犀斋,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 43,490評論 2 348

推薦閱讀更多精彩內(nèi)容