object_detectionAPI源碼閱讀筆記(16-通過config文件查看源碼)

這里的源碼是從train.py開始看的迅栅。之后還有eval.py

train.py

在trian.py中config文件适秩,被分成三分

  model_config = configs['model']
  train_config = configs['train_config']
  input_config = configs['train_input_config']

moedel_config是構(gòu)建模型的文件陆蟆。

  model_fn = functools.partial(
      model_builder.build,
      model_config=model_config,
      is_training=True)

在model_bulider.py中build會選擇模型種類

def build(model_config, is_training):
  """Builds a DetectionModel based on the model config.

  Args:
    model_config: A model.proto object containing the config for the desired
      DetectionModel.
    is_training: True if this model is being built for training purposes.

  Returns:
    DetectionModel based on the config.

  Raises:
    ValueError: On invalid meta architecture or model.
  """
  if not isinstance(model_config, model_pb2.DetectionModel):
    raise ValueError('model_config not of type model_pb2.DetectionModel.')
  meta_architecture = model_config.WhichOneof('model')
  if meta_architecture == 'ssd':
    return _build_ssd_model(model_config.ssd, is_training)
  if meta_architecture == 'faster_rcnn':
    return _build_faster_rcnn_model(model_config.faster_rcnn, is_training)
  raise ValueError('Unknown meta architecture: {}'.format(meta_architecture))

如果你選擇faster-rcnn,在model_builder.py中這些都是構(gòu)建faster-rcnn模型的參數(shù)

如果你有興趣,在protos/model_pb2.py有很多model_config的默認值

這時候模型已經(jīng)構(gòu)建完了

回到train.py中

  train_config = configs['train_config']

發(fā)現(xiàn)這是對trainer.py進行的配置文件瞬逊,在trainer.py的train函數(shù)中哎媚,如下:

在protos/train_pb2.py中的默認配置如下:

_descriptor.FieldDescriptor(
      name='batch_size', full_name='object_detection.protos.TrainConfig.batch_size', index=0,
      number=1, type=13, cpp_type=3, label=1,
      has_default_value=True, default_value=32,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='data_augmentation_options', full_name='object_detection.protos.TrainConfig.data_augmentation_options', index=1,
      number=2, type=11, cpp_type=10, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='sync_replicas', full_name='object_detection.protos.TrainConfig.sync_replicas', index=2,
      number=3, type=8, cpp_type=7, label=1,
      has_default_value=True, default_value=False,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='keep_checkpoint_every_n_hours', full_name='object_detection.protos.TrainConfig.keep_checkpoint_every_n_hours', index=3,
      number=4, type=13, cpp_type=3, label=1,
      has_default_value=True, default_value=1000,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='optimizer', full_name='object_detection.protos.TrainConfig.optimizer', index=4,
      number=5, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='gradient_clipping_by_norm', full_name='object_detection.protos.TrainConfig.gradient_clipping_by_norm', index=5,
      number=6, type=2, cpp_type=6, label=1,
      has_default_value=True, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='fine_tune_checkpoint', full_name='object_detection.protos.TrainConfig.fine_tune_checkpoint', index=6,
      number=7, type=9, cpp_type=9, label=1,
      has_default_value=True, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='from_detection_checkpoint', full_name='object_detection.protos.TrainConfig.from_detection_checkpoint', index=7,
      number=8, type=8, cpp_type=7, label=1,
      has_default_value=True, default_value=False,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='num_steps', full_name='object_detection.protos.TrainConfig.num_steps', index=8,
      number=9, type=13, cpp_type=3, label=1,
      has_default_value=True, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='startup_delay_steps', full_name='object_detection.protos.TrainConfig.startup_delay_steps', index=9,
      number=10, type=2, cpp_type=6, label=1,
      has_default_value=True, default_value=15,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='bias_grad_multiplier', full_name='object_detection.protos.TrainConfig.bias_grad_multiplier', index=10,
      number=11, type=2, cpp_type=6, label=1,
      has_default_value=True, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='freeze_variables', full_name='object_detection.protos.TrainConfig.freeze_variables', index=11,
      number=12, type=9, cpp_type=9, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='replicas_to_aggregate', full_name='object_detection.protos.TrainConfig.replicas_to_aggregate', index=12,
      number=13, type=5, cpp_type=1, label=1,
      has_default_value=True, default_value=1,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='batch_queue_capacity', full_name='object_detection.protos.TrainConfig.batch_queue_capacity', index=13,
      number=14, type=5, cpp_type=1, label=1,
      has_default_value=True, default_value=150,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='num_batch_queue_threads', full_name='object_detection.protos.TrainConfig.num_batch_queue_threads', index=14,
      number=15, type=5, cpp_type=1, label=1,
      has_default_value=True, default_value=8,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='prefetch_queue_capacity', full_name='object_detection.protos.TrainConfig.prefetch_queue_capacity', index=15,
      number=16, type=5, cpp_type=1, label=1,
      has_default_value=True, default_value=5,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='merge_multiple_label_boxes', full_name='object_detection.protos.TrainConfig.merge_multiple_label_boxes', index=16,
      number=17, type=8, cpp_type=7, label=1,
      has_default_value=True, default_value=False,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),

再看input_config

 input_config = configs['train_input_config']

在builder/input_reader_builder中

input_reader_pb2中默認值:

    _descriptor.FieldDescriptor(
      name='label_map_path', full_name='object_detection.protos.InputReader.label_map_path', index=0,
      number=1, type=9, cpp_type=9, label=1,
      has_default_value=True, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='shuffle', full_name='object_detection.protos.InputReader.shuffle', index=1,
      number=2, type=8, cpp_type=7, label=1,
      has_default_value=True, default_value=True,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='queue_capacity', full_name='object_detection.protos.InputReader.queue_capacity', index=2,
      number=3, type=13, cpp_type=3, label=1,
      has_default_value=True, default_value=2000,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='min_after_dequeue', full_name='object_detection.protos.InputReader.min_after_dequeue', index=3,
      number=4, type=13, cpp_type=3, label=1,
      has_default_value=True, default_value=1000,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='num_epochs', full_name='object_detection.protos.InputReader.num_epochs', index=4,
      number=5, type=13, cpp_type=3, label=1,
      has_default_value=True, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='num_readers', full_name='object_detection.protos.InputReader.num_readers', index=5,
      number=6, type=13, cpp_type=3, label=1,
      has_default_value=True, default_value=8,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='load_instance_masks', full_name='object_detection.protos.InputReader.load_instance_masks', index=6,
      number=7, type=8, cpp_type=7, label=1,
      has_default_value=True, default_value=False,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='tf_record_input_reader', full_name='object_detection.protos.InputReader.tf_record_input_reader', index=7,
      number=8, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='external_input_reader', full_name='object_detection.protos.InputReader.external_input_reader', index=8,
      number=9, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),

value.py

在value.py中也是被分為三個部分

  model_config = configs['model']
  eval_config = configs['eval_config']
  if FLAGS.eval_training_data:
    input_config = configs['train_input_config']
  else:
    input_config = configs['eval_input_config']

這里

eval_config = configs['eval_config']

為新增的一個配置文件垫言,進行計算評估用的一個文件。

evaluator.py文件使用了這里的config文件參數(shù)

文件開頭的幾種分數(shù)評估方式瞻讽。

EVAL_METRICS_CLASS_DICT = {
    'pascal_voc_metrics':
        object_detection_evaluation.PascalDetectionEvaluator,
    'weighted_pascal_voc_metrics':
        object_detection_evaluation.WeightedPascalDetectionEvaluator,
    'open_images_metrics':
        object_detection_evaluation.OpenImagesDetectionEvaluator
}

eval_pb2.py文件中的eval_config的默認值鸳吸。

_descriptor.FieldDescriptor(
      name='num_visualizations', full_name='object_detection.protos.EvalConfig.num_visualizations', index=0,
      number=1, type=13, cpp_type=3, label=1,
      has_default_value=True, default_value=10,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='num_examples', full_name='object_detection.protos.EvalConfig.num_examples', index=1,
      number=2, type=13, cpp_type=3, label=1,
      has_default_value=True, default_value=5000,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='eval_interval_secs', full_name='object_detection.protos.EvalConfig.eval_interval_secs', index=2,
      number=3, type=13, cpp_type=3, label=1,
      has_default_value=True, default_value=300,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='max_evals', full_name='object_detection.protos.EvalConfig.max_evals', index=3,
      number=4, type=13, cpp_type=3, label=1,
      has_default_value=True, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='save_graph', full_name='object_detection.protos.EvalConfig.save_graph', index=4,
      number=5, type=8, cpp_type=7, label=1,
      has_default_value=True, default_value=False,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='visualization_export_dir', full_name='object_detection.protos.EvalConfig.visualization_export_dir', index=5,
      number=6, type=9, cpp_type=9, label=1,
      has_default_value=True, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='eval_master', full_name='object_detection.protos.EvalConfig.eval_master', index=6,
      number=7, type=9, cpp_type=9, label=1,
      has_default_value=True, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='metrics_set', full_name='object_detection.protos.EvalConfig.metrics_set', index=7,
      number=8, type=9, cpp_type=9, label=1,
      has_default_value=True, default_value=_b("pascal_voc_metrics").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='export_path', full_name='object_detection.protos.EvalConfig.export_path', index=8,
      number=9, type=9, cpp_type=9, label=1,
      has_default_value=True, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='ignore_groundtruth', full_name='object_detection.protos.EvalConfig.ignore_groundtruth', index=9,
      number=10, type=8, cpp_type=7, label=1,
      has_default_value=True, default_value=False,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='use_moving_averages', full_name='object_detection.protos.EvalConfig.use_moving_averages', index=10,
      number=11, type=8, cpp_type=7, label=1,
      has_default_value=True, default_value=False,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='eval_instance_masks', full_name='object_detection.protos.EvalConfig.eval_instance_masks', index=11,
      number=12, type=8, cpp_type=7, label=1,
      has_default_value=True, default_value=False,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None)

所有的超參數(shù)的默認值都可以在config文件中進行修改熏挎。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末速勇,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子坎拐,更是在濱河造成了極大的恐慌烦磁,老刑警劉巖,帶你破解...
    沈念sama閱讀 221,820評論 6 515
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件哼勇,死亡現(xiàn)場離奇詭異都伪,居然都是意外死亡,警方通過查閱死者的電腦和手機积担,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 94,648評論 3 399
  • 文/潘曉璐 我一進店門陨晶,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人帝璧,你說我怎么就攤上這事先誉。” “怎么了的烁?”我有些...
    開封第一講書人閱讀 168,324評論 0 360
  • 文/不壞的土叔 我叫張陵褐耳,是天一觀的道長。 經(jīng)常有香客問我渴庆,道長铃芦,這世上最難降的妖魔是什么雅镊? 我笑而不...
    開封第一講書人閱讀 59,714評論 1 297
  • 正文 為了忘掉前任,我火速辦了婚禮刃滓,結(jié)果婚禮上仁烹,老公的妹妹穿的比我還像新娘。我一直安慰自己注盈,他們只是感情好晃危,可當我...
    茶點故事閱讀 68,724評論 6 397
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著老客,像睡著了一般僚饭。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上胧砰,一...
    開封第一講書人閱讀 52,328評論 1 310
  • 那天鳍鸵,我揣著相機與錄音,去河邊找鬼尉间。 笑死偿乖,一個胖子當著我的面吹牛,可吹牛的內(nèi)容都是我干的哲嘲。 我是一名探鬼主播贪薪,決...
    沈念sama閱讀 40,897評論 3 421
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼眠副!你這毒婦竟也來了画切?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 39,804評論 0 276
  • 序言:老撾萬榮一對情侶失蹤囱怕,失蹤者是張志新(化名)和其女友劉穎霍弹,沒想到半個月后,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體娃弓,經(jīng)...
    沈念sama閱讀 46,345評論 1 318
  • 正文 獨居荒郊野嶺守林人離奇死亡典格,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 38,431評論 3 340
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了台丛。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片耍缴。...
    茶點故事閱讀 40,561評論 1 352
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖挽霉,靈堂內(nèi)的尸體忽然破棺而出防嗡,到底是詐尸還是另有隱情,我是刑警寧澤炼吴,帶...
    沈念sama閱讀 36,238評論 5 350
  • 正文 年R本政府宣布本鸣,位于F島的核電站,受9級特大地震影響硅蹦,放射性物質(zhì)發(fā)生泄漏荣德。R本人自食惡果不足惜闷煤,卻給世界環(huán)境...
    茶點故事閱讀 41,928評論 3 334
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望涮瞻。 院中可真熱鬧鲤拿,春花似錦、人聲如沸署咽。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,417評論 0 24
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽宁否。三九已至窒升,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間慕匠,已是汗流浹背饱须。 一陣腳步聲響...
    開封第一講書人閱讀 33,528評論 1 272
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留台谊,地道東北人蓉媳。 一個月前我還...
    沈念sama閱讀 48,983評論 3 376
  • 正文 我出身青樓,卻偏偏與公主長得像锅铅,于是被迫代替她去往敵國和親酪呻。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 45,573評論 2 359

推薦閱讀更多精彩內(nèi)容