在使用Google提供的模型時,可能會報錯,這就需要我們了解程序加載模型的具體細節。
訓練的開始語句是
python train.py --train_dir='train_dir' --pipeline_config_path='pipeline_config_path'
核心調用的文件就是train.py。在train.py中通過
model_config, train_config, input_config = get_configs_from_pipeline_file()
獲取配置信息,其中是調用protobuf進行文件解析。之後分別得到model_config、train_config、input_config(訓練的格式)。
之後通過functools.partial函數對model_builder.build函數賦予默認值。通過functools.partial函數對input_reader_builder.build賦予默認值。train_config是在最後訓練的時候進行傳入。
生成網絡模型的代碼為:
def build(model_config, is_training):
    """Build a detection model from a ``model_pb2.DetectionModel`` config.

    Args:
        model_config: A ``model_pb2.DetectionModel`` message. Whichever field
            of its ``model`` oneof is set selects the meta-architecture.
        is_training: Forwarded unchanged to the architecture-specific builder.

    Returns:
        The result of ``_build_ssd_model`` or ``_build_faster_rcnn_model``
        for the selected architecture.

    Raises:
        ValueError: If ``model_config`` is not a ``DetectionModel``, or the
            ``model`` oneof is neither ``'ssd'`` nor ``'faster_rcnn'``.
    """
    if not isinstance(model_config, model_pb2.DetectionModel):
        raise ValueError('model_config not of type model_pb2.DetectionModel.')

    # Name of the field currently set in the `model` oneof.
    arch_name = model_config.WhichOneof('model')
    if arch_name == 'ssd':
        builder = _build_ssd_model
    elif arch_name == 'faster_rcnn':
        builder = _build_faster_rcnn_model
    else:
        raise ValueError('Unknown meta architecture: {}'.format(arch_name))

    # The oneof field name is also the attribute that holds the sub-config.
    return builder(getattr(model_config, arch_name), is_training)
之後以'faster_rcnn'模型為例子,進入_build_faster_rcnn_model
def _build_faster_rcnn_model(frcnn_config, is_training):
    """Build a Faster R-CNN or R-FCN detection model from its config.

    The meta-architecture is decided by the type of the second-stage box
    predictor built from the config: an ``RfcnBoxPredictor`` produces an
    ``RFCNMetaArch``; any other predictor produces a ``FasterRCNNMetaArch``.

    Args:
        frcnn_config: Config object describing the model structure.
        is_training: Whether the model is being built for training; passed
            to the sub-builders and to the meta-architecture constructor.

    Returns:
        An ``rfcn_meta_arch.RFCNMetaArch`` or a
        ``faster_rcnn_meta_arch.FasterRCNNMetaArch`` instance.
    """
    # --- Sub-components that need a dedicated builder. -------------------
    # The builders run in this order; the first invalid sub-config raises.
    image_resizer_fn = image_resizer_builder.build(frcnn_config.image_resizer)
    feature_extractor = _build_faster_rcnn_feature_extractor(
        frcnn_config.feature_extractor, is_training)
    first_stage_anchor_generator = anchor_generator_builder.build(
        frcnn_config.first_stage_anchor_generator)
    first_stage_box_predictor_arg_scope = hyperparams_builder.build(
        frcnn_config.first_stage_box_predictor_conv_hyperparams, is_training)
    second_stage_box_predictor = box_predictor_builder.build(
        hyperparams_builder.build,
        frcnn_config.second_stage_box_predictor,
        is_training=is_training,
        num_classes=frcnn_config.num_classes)
    nms_fn, score_conversion_fn = post_processing_builder.build(
        frcnn_config.second_stage_post_processing)

    # Hard example mining is off unless the config explicitly carries a
    # `hard_example_miner` field.
    hard_example_miner = None
    if frcnn_config.HasField('hard_example_miner'):
        hard_example_miner = losses_builder.build_hard_example_miner(
            frcnn_config.hard_example_miner,
            frcnn_config.second_stage_classification_loss_weight,
            frcnn_config.second_stage_localization_loss_weight)

    # --- Plain settings. --------------------------------------------------
    # Each of these config fields has the same name as the constructor
    # keyword it feeds, so the values are copied across unchanged.
    passthrough_fields = (
        'num_classes',
        'first_stage_only',
        'first_stage_atrous_rate',
        'first_stage_box_predictor_kernel_size',
        'first_stage_box_predictor_depth',
        'first_stage_minibatch_size',
        'first_stage_positive_balance_fraction',
        'first_stage_nms_score_threshold',
        'first_stage_nms_iou_threshold',
        'first_stage_max_proposals',
        'first_stage_localization_loss_weight',
        'first_stage_objectness_loss_weight',
        'second_stage_batch_size',
        'second_stage_balance_fraction',
        'second_stage_localization_loss_weight',
        'second_stage_classification_loss_weight',
    )
    common_kwargs = {
        field: getattr(frcnn_config, field) for field in passthrough_fields
    }
    common_kwargs.update({
        'is_training': is_training,
        'image_resizer_fn': image_resizer_fn,
        'feature_extractor': feature_extractor,
        'first_stage_anchor_generator': first_stage_anchor_generator,
        'first_stage_box_predictor_arg_scope':
            first_stage_box_predictor_arg_scope,
        'second_stage_non_max_suppression_fn': nms_fn,
        'second_stage_score_conversion_fn': score_conversion_fn,
        'hard_example_miner': hard_example_miner,
    })

    # An R-FCN predictor selects the R-FCN meta-architecture.
    if isinstance(second_stage_box_predictor, box_predictor.RfcnBoxPredictor):
        return rfcn_meta_arch.RFCNMetaArch(
            second_stage_rfcn_box_predictor=second_stage_box_predictor,
            **common_kwargs)

    # Any other predictor selects Faster R-CNN, which additionally takes the
    # crop and max-pool settings.
    return faster_rcnn_meta_arch.FasterRCNNMetaArch(
        initial_crop_size=frcnn_config.initial_crop_size,
        maxpool_kernel_size=frcnn_config.maxpool_kernel_size,
        maxpool_stride=frcnn_config.maxpool_stride,
        second_stage_mask_rcnn_box_predictor=second_stage_box_predictor,
        **common_kwargs)
之後說明每一個子模型的構建
首先是image_resizer_builder的模型構建
# 構建圖片的resize
def build(image_resizer_config):
    """Return an image-resizing callable described by the config.

    Args:
        image_resizer_config: An ``image_resizer_pb2.ImageResizer`` message.

    Returns:
        A ``functools.partial`` over ``preprocessor.resize_to_range`` (for
        ``keep_aspect_ratio_resizer``) or over ``preprocessor.resize_image``
        (for ``fixed_shape_resizer``), with the configured sizes bound.

    Raises:
        ValueError: If the config has the wrong type, if ``min_dimension``
            is not <= ``max_dimension``, or if neither supported resizer is
            set in ``image_resizer_oneof``.
    """
    if not isinstance(image_resizer_config, image_resizer_pb2.ImageResizer):
        raise ValueError('image_resizer_config not of type '
                         'image_resizer_pb2.ImageResizer.')

    resizer_kind = image_resizer_config.WhichOneof('image_resizer_oneof')

    if resizer_kind == 'keep_aspect_ratio_resizer':
        # Aspect-ratio-preserving resize, bounded by a min and max dimension.
        ratio_cfg = image_resizer_config.keep_aspect_ratio_resizer
        bounds_ok = ratio_cfg.min_dimension <= ratio_cfg.max_dimension
        if not bounds_ok:
            raise ValueError('min_dimension > max_dimension')
        return functools.partial(
            preprocessor.resize_to_range,
            min_dimension=ratio_cfg.min_dimension,
            max_dimension=ratio_cfg.max_dimension)

    if resizer_kind == 'fixed_shape_resizer':
        # Resize every image to one fixed height and width.
        fixed_cfg = image_resizer_config.fixed_shape_resizer
        return functools.partial(
            preprocessor.resize_image,
            new_height=fixed_cfg.height,
            new_width=fixed_cfg.width)

    raise ValueError('Invalid image resizer option.')
接下來看preprocessor.resize_to_range這個函數
def resize_to_range(image,
                    masks=None,
                    min_dimension=None,
                    max_dimension=None,
                    align_corners=False):
    """Resize an image, and optionally its masks, to a computed target size.

    The target size comes from ``_compute_new_static_size`` when the image
    shape is fully defined, otherwise from ``_compute_new_dynamic_size``.
    NOTE(review): per the original notes, the size is chosen so that either
    the short side equals ``min_dimension`` with the long side within
    ``max_dimension``, or the long side equals ``max_dimension``. That logic
    lives in the two helpers and is not visible here.

    Args:
        image: A rank-3 tensor.
        masks: Optional tensor of masks. An axis is added at index 3 before
            nearest-neighbour resizing and removed again afterwards.
        min_dimension: Passed to the size helper.
        max_dimension: Passed to the size helper.
        align_corners: Forwarded to both TensorFlow resize ops.

    Returns:
        The resized image when ``masks`` is None; otherwise the two-element
        list ``[resized_image, resized_masks]``.

    Raises:
        ValueError: If ``image`` is not rank 3.
    """
    image_shape = image.get_shape()
    if len(image_shape) != 3:
        raise ValueError('Image should be 3D tensor')

    with tf.name_scope('ResizeToRange', values=[image, min_dimension]):
        if image_shape.is_fully_defined():
            compute_size = _compute_new_static_size
        else:
            compute_size = _compute_new_dynamic_size
        target_size = compute_size(image, min_dimension, max_dimension)

        resized_image = tf.image.resize_images(
            image, target_size, align_corners=align_corners)
        if masks is None:
            return resized_image

        # resize_nearest_neighbor is fed a 4-D input, so add a trailing axis
        # for the resize and drop it again afterwards.
        expanded_masks = tf.expand_dims(masks, 3)
        resized_masks = tf.image.resize_nearest_neighbor(
            expanded_masks, target_size, align_corners=align_corners)
        return [resized_image, tf.squeeze(resized_masks, 3)]
resize之後就是構建faster_rcnn_meta_arch所需的特徵提取器,也就是進行_build_faster_rcnn_feature_extractor函數的說明
def _build_faster_rcnn_feature_extractor(
        feature_extractor_config, is_training, reuse_weights=None):
    """Instantiate the Faster R-CNN feature extractor named in the config.

    Args:
        feature_extractor_config: Config exposing ``type`` (a key of
            ``FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP``, e.g.
            ``faster_rcnn_resnet101`` per the original notes) and
            ``first_stage_features_stride``.
        is_training: Passed to the extractor constructor.
        reuse_weights: Passed to the extractor constructor.

    Returns:
        An instance of the extractor class registered under
        ``feature_extractor_config.type``.

    Raises:
        ValueError: If the type is not a registered feature extractor.
    """
    extractor_name = feature_extractor_config.type
    # NOTE(review): the original notes say only strides 8 and 16 are
    # accepted. That check is not made in this function; confirm it against
    # the extractor classes.
    stride = feature_extractor_config.first_stage_features_stride

    if extractor_name in FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP:
        extractor_cls = FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP[extractor_name]
        return extractor_cls(is_training, stride, reuse_weights)

    raise ValueError('Unknown Faster R-CNN feature_extractor: {}'.format(
        extractor_name))
接下來是anchor的構建
def build(anchor_generator_config):
    """Build an anchor generator from its config.

    Exactly one of ``grid_anchor_generator`` or ``ssd_anchor_generator`` is
    expected to be set in ``anchor_generator_oneof``.

    Args:
        anchor_generator_config: An ``anchor_generator_pb2.AnchorGenerator``
            message.

    Returns:
        A ``grid_anchor_generator.GridAnchorGenerator`` for the grid option,
        or the result of ``multiple_grid_anchor_generator.create_ssd_anchors``
        for the SSD option.

    Raises:
        ValueError: If the config has the wrong type or neither option is set.
    """
    if not isinstance(anchor_generator_config,
                      anchor_generator_pb2.AnchorGenerator):
        raise ValueError('anchor_generator_config not of type '
                         'anchor_generator_pb2.AnchorGenerator')

    generator_kind = anchor_generator_config.WhichOneof(
        'anchor_generator_oneof')

    if generator_kind == 'grid_anchor_generator':
        grid_cfg = anchor_generator_config.grid_anchor_generator
        # Scales and aspect ratios are handed over as plain Python floats.
        scales = [float(value) for value in grid_cfg.scales]
        aspect_ratios = [float(value) for value in grid_cfg.aspect_ratios]
        return grid_anchor_generator.GridAnchorGenerator(
            scales=scales,
            aspect_ratios=aspect_ratios,
            base_anchor_size=[grid_cfg.height, grid_cfg.width],
            anchor_stride=[grid_cfg.height_stride, grid_cfg.width_stride],
            anchor_offset=[grid_cfg.height_offset, grid_cfg.width_offset])

    if generator_kind == 'ssd_anchor_generator':
        ssd_cfg = anchor_generator_config.ssd_anchor_generator
        return multiple_grid_anchor_generator.create_ssd_anchors(
            num_layers=ssd_cfg.num_layers,
            min_scale=ssd_cfg.min_scale,
            max_scale=ssd_cfg.max_scale,
            aspect_ratios=ssd_cfg.aspect_ratios,
            reduce_boxes_in_lowest_layer=ssd_cfg.reduce_boxes_in_lowest_layer)

    raise ValueError('Empty anchor generator.')
接下來是構建hyperparams_builder.build的那個模塊
def build(hyperparams_config, is_training):
    """Build a tf-slim ``arg_scope`` from a hyperparameter config.

    The scope sets the weights regularizer, weights initializer, activation
    function and, when the config has a ``batch_norm`` field, batch
    normalization. Without that field no normalizer is installed.

    Args:
        hyperparams_config: A ``hyperparams_pb2.Hyperparams`` message.
        is_training: Passed to ``_build_batch_norm_params``.
            NOTE(review): per the original notes, whether batch-norm
            parameters are trained depends on this flag together with the
            config's ``batch_norm.train``; confirm in that helper.

    Returns:
        The ``arg_scope`` object produced by ``slim.arg_scope``.

    Raises:
        ValueError: If ``hyperparams_config`` has the wrong type.
    """
    if not isinstance(hyperparams_config, hyperparams_pb2.Hyperparams):
        raise ValueError('hyperparams_config not of type '
                         'hyperparams_pb.Hyperparams.')

    if hyperparams_config.HasField('batch_norm'):
        normalizer_fn = slim.batch_norm
        normalizer_params = _build_batch_norm_params(
            hyperparams_config.batch_norm, is_training)
    else:
        normalizer_fn = None
        normalizer_params = None

    # The scope applies to the convolution ops unless the config explicitly
    # selects the fully connected op.
    scoped_ops = [slim.conv2d, slim.separable_conv2d, slim.conv2d_transpose]
    wants_fc = hyperparams_config.HasField('op') and (
        hyperparams_config.op == hyperparams_pb2.Hyperparams.FC)
    if wants_fc:
        scoped_ops = [slim.fully_connected]

    regularizer = _build_regularizer(hyperparams_config.regularizer)
    initializer = _build_initializer(hyperparams_config.initializer)
    activation_fn = _build_activation_fn(hyperparams_config.activation)

    with slim.arg_scope(
            scoped_ops,
            weights_regularizer=regularizer,
            weights_initializer=initializer,
            activation_fn=activation_fn,
            normalizer_fn=normalizer_fn,
            normalizer_params=normalizer_params) as scope:
        return scope
接下來是bbox預測器的構建
def build(argscope_fn, box_predictor_config, is_training, num_classes):
    """Build a box predictor from its config.

    Args:
        argscope_fn: Callable taking a hyperparams config and the
            ``is_training`` flag and returning an arg_scope (for example
            ``hyperparams_builder.build``).
        box_predictor_config: A ``box_predictor_pb2.BoxPredictor`` message.
        is_training: Passed to ``argscope_fn`` and the predictor constructor.
        num_classes: Passed to the predictor constructor.

    Returns:
        A ``ConvolutionalBoxPredictor``, ``MaskRCNNBoxPredictor`` or
        ``RfcnBoxPredictor``, depending on which ``box_predictor_oneof``
        field is set.

    Raises:
        ValueError: If the config has the wrong type or names an unknown
            predictor.
    """
    if not isinstance(box_predictor_config, box_predictor_pb2.BoxPredictor):
        raise ValueError('box_predictor_config not of type '
                         'box_predictor_pb2.BoxPredictor.')

    predictor_kind = box_predictor_config.WhichOneof('box_predictor_oneof')

    if predictor_kind == 'convolutional_box_predictor':
        conv_cfg = box_predictor_config.convolutional_box_predictor
        conv_scope = argscope_fn(conv_cfg.conv_hyperparams, is_training)
        return box_predictor.ConvolutionalBoxPredictor(
            is_training=is_training,
            num_classes=num_classes,
            conv_hyperparams=conv_scope,
            min_depth=conv_cfg.min_depth,
            max_depth=conv_cfg.max_depth,
            num_layers_before_predictor=conv_cfg.num_layers_before_predictor,
            use_dropout=conv_cfg.use_dropout,
            dropout_keep_prob=conv_cfg.dropout_keep_probability,
            kernel_size=conv_cfg.kernel_size,
            box_code_size=conv_cfg.box_code_size,
            apply_sigmoid_to_scores=conv_cfg.apply_sigmoid_to_scores)

    if predictor_kind == 'mask_rcnn_box_predictor':
        mask_cfg = box_predictor_config.mask_rcnn_box_predictor
        fc_scope = argscope_fn(mask_cfg.fc_hyperparams, is_training)
        # Convolution hyperparameters are optional for this predictor.
        if mask_cfg.HasField('conv_hyperparams'):
            conv_scope = argscope_fn(mask_cfg.conv_hyperparams, is_training)
        else:
            conv_scope = None
        return box_predictor.MaskRCNNBoxPredictor(
            is_training=is_training,
            num_classes=num_classes,
            fc_hyperparams=fc_scope,
            use_dropout=mask_cfg.use_dropout,
            dropout_keep_prob=mask_cfg.dropout_keep_probability,
            box_code_size=mask_cfg.box_code_size,
            conv_hyperparams=conv_scope,
            predict_instance_masks=mask_cfg.predict_instance_masks,
            mask_prediction_conv_depth=mask_cfg.mask_prediction_conv_depth,
            predict_keypoints=mask_cfg.predict_keypoints)

    if predictor_kind == 'rfcn_box_predictor':
        # Second-stage predictor for the R-FCN architecture.
        rfcn_cfg = box_predictor_config.rfcn_box_predictor
        conv_scope = argscope_fn(rfcn_cfg.conv_hyperparams, is_training)
        return box_predictor.RfcnBoxPredictor(
            is_training=is_training,
            num_classes=num_classes,
            conv_hyperparams=conv_scope,
            crop_size=[rfcn_cfg.crop_height, rfcn_cfg.crop_width],
            num_spatial_bins=[rfcn_cfg.num_spatial_bins_height,
                              rfcn_cfg.num_spatial_bins_width],
            depth=rfcn_cfg.depth,
            box_code_size=rfcn_cfg.box_code_size)

    raise ValueError('Unknown box predictor: {}'.format(predictor_kind))
上面的函數中有hyperparams_builder.build,那麼就看看這個(與前面列出的hyperparams_builder.build是同一個函數)
def build(hyperparams_config, is_training):
    """Return a tf-slim ``arg_scope`` configured from ``hyperparams_config``.

    Args:
        hyperparams_config: A ``hyperparams_pb2.Hyperparams`` message that
            supplies the regularizer, initializer and activation, plus an
            optional ``batch_norm`` field and an optional ``op`` field.
        is_training: Passed to ``_build_batch_norm_params`` when batch norm
            is configured.

    Returns:
        The ``arg_scope`` object produced by ``slim.arg_scope``.

    Raises:
        ValueError: If ``hyperparams_config`` has the wrong type.
    """
    if not isinstance(hyperparams_config, hyperparams_pb2.Hyperparams):
        raise ValueError('hyperparams_config not of type '
                         'hyperparams_pb.Hyperparams.')

    # Batch norm is installed only when the config carries the field.
    if hyperparams_config.HasField('batch_norm'):
        normalizer_fn = slim.batch_norm
        normalizer_params = _build_batch_norm_params(
            hyperparams_config.batch_norm, is_training)
    else:
        normalizer_fn = None
        normalizer_params = None

    # Convolution ops by default; the fully connected op only when asked for.
    scoped_ops = [slim.conv2d, slim.separable_conv2d, slim.conv2d_transpose]
    wants_fc = hyperparams_config.HasField('op') and (
        hyperparams_config.op == hyperparams_pb2.Hyperparams.FC)
    if wants_fc:
        scoped_ops = [slim.fully_connected]

    regularizer = _build_regularizer(hyperparams_config.regularizer)
    initializer = _build_initializer(hyperparams_config.initializer)
    activation_fn = _build_activation_fn(hyperparams_config.activation)

    with slim.arg_scope(
            scoped_ops,
            weights_regularizer=regularizer,
            weights_initializer=initializer,
            activation_fn=activation_fn,
            normalizer_fn=normalizer_fn,
            normalizer_params=normalizer_params) as scope:
        return scope
已經獲取了box以及預測的類別,之後就是要進行一些後處理,可以看看後處理的構建post_processing_builder.build(frcnn_config.second_stage_post_processing)的具體內容。
def build(post_processing_config):
    """Build the post-processing callables described by the config.

    Args:
        post_processing_config: A ``post_processing_pb2.PostProcessing``
            message with ``batch_non_max_suppression`` and
            ``score_converter`` fields.

    Returns:
        A ``(non_max_suppressor_fn, score_converter_fn)`` tuple, built by
        ``_build_non_max_suppressor`` and ``_build_score_converter``.

    Raises:
        ValueError: If ``post_processing_config`` has the wrong type.
    """
    if not isinstance(post_processing_config,
                      post_processing_pb2.PostProcessing):
        raise ValueError('post_processing_config not of type '
                         'post_processing_pb2.Postprocessing.')

    # The NMS callable is built first, then the score converter.
    return (
        _build_non_max_suppressor(
            post_processing_config.batch_non_max_suppression),
        _build_score_converter(post_processing_config.score_converter),
    )
nms的構建,繼續看post_processing.batch_multiclass_non_max_suppression這個函數
def _build_non_max_suppressor(nms_config):
if nms_config.iou_threshold < 0 or nms_config.iou_threshold > 1.0:
raise ValueError('iou_threshold not in [0, 1.0].')
if nms_config.max_detections_per_class > nms_config.max_total_detections:
raise ValueError('max_detections_per_class should be no greater than '
'max_total_detections.')
non_max_suppressor_fn = functools.partial(
post_processing.batch_multiclass_non_max_suppression,
score_thresh=nms_config.score_threshold,
iou_thresh=nms_config.iou_threshold,
max_size_per_class=nms_config.max_detections_per_class,
max_total_size=nms_config.max_total_detections)
return non_max_suppressor_fn
不用說就是post_processing.batch_multiclass_non_max_suppression
#太長了,不複製了。和multiclass_non_max_suppression很相似,具體自己看
接下來是針對loss的build_hard_example_miner
def build_hard_example_miner(config,
                             classification_weight,
                             localization_weight):
    """Construct a ``losses.HardExampleMiner`` from its config.

    Args:
        config: Hard example miner config exposing ``loss_type``,
            ``max_negatives_per_positive``, ``num_hard_examples``,
            ``iou_threshold`` and ``min_negatives_per_image``.
        classification_weight: Passed on as ``cls_loss_weight``.
        localization_weight: Passed on as ``loc_loss_weight``.

    Returns:
        The configured ``losses.HardExampleMiner`` instance.
    """
    # Enum value -> string understood by HardExampleMiner. An unlisted value
    # leaves the loss type as None.
    loss_type_names = {
        losses_pb2.HardExampleMiner.BOTH: 'both',
        losses_pb2.HardExampleMiner.CLASSIFICATION: 'cls',
        losses_pb2.HardExampleMiner.LOCALIZATION: 'loc',
    }
    loss_type = loss_type_names.get(config.loss_type)

    # Non-positive limits are passed on as None rather than as numbers.
    max_negatives_per_positive = (
        config.max_negatives_per_positive
        if config.max_negatives_per_positive > 0 else None)
    num_hard_examples = (
        config.num_hard_examples if config.num_hard_examples > 0 else None)

    return losses.HardExampleMiner(
        num_hard_examples=num_hard_examples,
        iou_threshold=config.iou_threshold,
        loss_type=loss_type,
        cls_loss_weight=classification_weight,
        loc_loss_weight=localization_weight,
        max_negatives_per_positive=max_negatives_per_positive,
        min_negatives_per_image=config.min_negatives_per_image)
函數最後也就是最重要的rfcn_meta_arch.RFCNMetaArch,其實就是RFCNMetaArch的初始化。就是構建一個faster r-cnn的模型之後將第二階段進行替換。
class RFCNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
    """R-FCN meta-architecture, built on the Faster R-CNN meta-architecture."""

    def __init__(self,
                 is_training,
                 num_classes,
                 image_resizer_fn,
                 feature_extractor,
                 first_stage_only,
                 first_stage_anchor_generator,
                 first_stage_atrous_rate,
                 first_stage_box_predictor_arg_scope,
                 first_stage_box_predictor_kernel_size,
                 first_stage_box_predictor_depth,
                 first_stage_minibatch_size,
                 first_stage_positive_balance_fraction,
                 first_stage_nms_score_threshold,
                 first_stage_nms_iou_threshold,
                 first_stage_max_proposals,
                 first_stage_localization_loss_weight,
                 first_stage_objectness_loss_weight,
                 second_stage_rfcn_box_predictor,
                 second_stage_batch_size,
                 second_stage_balance_fraction,
                 second_stage_non_max_suppression_fn,
                 second_stage_score_conversion_fn,
                 second_stage_localization_loss_weight,
                 second_stage_classification_loss_weight,
                 hard_example_miner,
                 parallel_iterations=16):
        """Initialise the Faster R-CNN base and store the R-FCN box predictor.

        Every argument except ``second_stage_rfcn_box_predictor`` is passed
        positionally to the parent constructor. The four parent slots that
        R-FCN does not use are filled with ``None``.
        """
        # Arguments up to and including the first-stage loss weights.
        leading_args = (
            is_training,
            num_classes,
            image_resizer_fn,
            feature_extractor,
            first_stage_only,
            first_stage_anchor_generator,
            first_stage_atrous_rate,
            first_stage_box_predictor_arg_scope,
            first_stage_box_predictor_kernel_size,
            first_stage_box_predictor_depth,
            first_stage_minibatch_size,
            first_stage_positive_balance_fraction,
            first_stage_nms_score_threshold,
            first_stage_nms_iou_threshold,
            first_stage_max_proposals,
            first_stage_localization_loss_weight,
            first_stage_objectness_loss_weight,
        )
        # Parent slots that R-FCN leaves empty, in order: initial_crop_size,
        # maxpool_kernel_size, maxpool_stride, and the fully connected box
        # predictor.
        unused_by_rfcn = (None, None, None, None)
        trailing_args = (
            second_stage_batch_size,
            second_stage_balance_fraction,
            second_stage_non_max_suppression_fn,
            second_stage_score_conversion_fn,
            second_stage_localization_loss_weight,
            second_stage_classification_loss_weight,
            hard_example_miner,
            parallel_iterations,
        )
        super(RFCNMetaArch, self).__init__(
            *(leading_args + unused_by_rfcn + trailing_args))

        # The R-FCN predictor is kept on the instance instead of being handed
        # to the parent.
        self._rfcn_box_predictor = second_stage_rfcn_box_predictor