Text detection and recognition in natural scenes: EAST text detector and recognition
I have recently been working on inspection robots and instrument-reading algorithms. Besides pointer-type gauges and indicator lights, some of the photos an inspection robot takes show instruments with digital displays, so reading those values requires the backend code to both detect and recognize text.
In addition, some projects call for recognizing the numbers on moving freight cars or tanks, so a single pipeline of this kind can handle all of these problems.
1. EAST text detector model
The natural-scene text detection model follows the method of Zhou et al., "EAST: An Efficient and Accurate Scene Text Detector" (arXiv:1704.03155).
- Uses a ResNet-50 residual network as the base (a sketch of the feature-extractor helper follows this list).
- Uses a class-balanced cross-entropy loss for the score map plus IoU and angle losses for the geometry, matching the losses imported in the code below.
- Uses an SGD optimizer with momentum and gradient clipping (see compile below).
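The listing below imports a resnet50 helper from east.layers.base_net that returns the four backbone feature maps; that module is not shown in this post. Here is a minimal sketch of such a helper, assuming tf.keras-style ResNet50 layer names (older Keras releases name these stages differently):

# Hypothetical sketch of east.layers.base_net.resnet50: reuse a pretrained
# ResNet-50 and expose the four stage outputs the EAST head needs.
from keras.applications import ResNet50

def resnet50(input_image):
    """Return the four ResNet-50 stage outputs, coarsest first (f1..f4)."""
    backbone = ResNet50(input_tensor=input_image, include_top=False)
    names = ['conv5_block3_out',  # stride 32 -> f1
             'conv4_block6_out',  # stride 16 -> f2
             'conv3_block4_out',  # stride 8  -> f3
             'conv2_block3_out']  # stride 4  -> f4
    return [backbone.get_layer(name).output for name in names]

The EAST head and training graph are then defined as follows: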
import keras
from keras import layers, Input, Model
import tensorflow as tf
from east.layers.base_net import resnet50
from east.layers.losses import balanced_cross_entropy, iou_loss, angle_loss
from east.layers.rbox import dist_to_box
def merge_block(f_pre, f_cur, out_channels, index):
"""
east網(wǎng)絡(luò)特征合并塊
:param f_pre:
:param f_cur:
:param out_channels:輸出通道數(shù)
:param index:block index
:return:
"""
    # Upsample
up_sample = layers.UpSampling2D(size=2, name="east_up_sample_f{}".format(index - 1))(f_pre)
    # Concatenate with the current-scale feature
merge = layers.Concatenate(name='east_merge_{}'.format(index))([up_sample, f_cur])
    # 1x1 conv to reduce channels
x = layers.Conv2D(out_channels, (1, 1), padding='same', name='east_reduce_channel_conv_{}'.format(index))(merge)
x = layers.BatchNormalization(name='east_reduce_channel_bn_{}'.format(index))(x)
x = layers.Activation(activation='relu', name='east_reduce_channel_relu_{}'.format(index))(x)
    # 3x3 conv to extract features
x = layers.Conv2D(out_channels, (3, 3), padding='same', name='east_extract_feature_conv_{}'.format(index))(x)
x = layers.BatchNormalization(name='east_extract_feature_bn_{}'.format(index))(x)
x = layers.Activation(activation='relu', name='east_extract_feature_relu_{}'.format(index))(x)
return x
def east(features):
"""
east網(wǎng)絡(luò)頭
:param features: 特征列表: f1, f2, f3, f4分別代表32,16,8,4倍下采樣的特征
:return:
"""
f1, f2, f3, f4 = features
    # Feature-merging branch
h2 = merge_block(f1, f2, 128, 2)
h3 = merge_block(h2, f3, 64, 3)
h4 = merge_block(h3, f4, 32, 4)
    # Extract the g4 feature
x = layers.Conv2D(32, (3, 3), padding='same', name='east_g4_conv')(h4)
x = layers.BatchNormalization(name='east_g4_bn')(x)
x = layers.Activation(activation='relu', name='east_g4_relu')(x)
    # Predict the score map
predict_score = layers.Conv2D(1, (1, 1), name='predict_score_map')(x)
    # Predict the distances
    predict_geo_dist = layers.Conv2D(4, (1, 1), activation='relu', name='predict_geo_dist')(x)  # distances must be positive
    # Predict the angle
predict_geo_angle = layers.Conv2D(1, (1, 1), name='predict_geo_angle')(x)
return predict_score, predict_geo_dist, predict_geo_angle
def east_net(config, stage='train'):
    # Inputs
h, w = list(config.IMAGE_SHAPE)[:2]
    h, w = h // 4, w // 4  # predictions are at 1/4 of the input resolution; integer division keeps the shapes valid
input_image = Input(shape=config.IMAGE_SHAPE, name='input_image')
input_score_map = Input(shape=(h, w), name='input_score')
    input_geo_dist = Input(shape=(h, w, 4), name='input_geo_dist')  # distances to the 4 rbox edges
    input_geo_angle = Input(shape=(h, w), name='input_geo_angle')  # rbox angle
input_mask = Input(shape=(h, w), name='input_mask')
input_image_meta = Input(shape=(12,), name='input_image_meta')
    # Network
features = resnet50(input_image)
predict_score, predict_geo_dist, predict_geo_angle = east(features)
if stage == 'train':
        # Add the loss layers
score_loss = layers.Lambda(lambda x: balanced_cross_entropy(*x), name='score_loss')(
[input_score_map, predict_score, input_mask])
geo_dist_loss = layers.Lambda(lambda x: iou_loss(*x), name='dist_loss')(
[input_geo_dist, predict_geo_dist, input_score_map, input_mask])
geo_angle_loss = layers.Lambda(lambda x: angle_loss(*x), name='angle_loss')(
[input_geo_angle, predict_geo_angle, input_score_map, input_mask])
return Model(inputs=[input_image, input_score_map, input_geo_dist, input_geo_angle, input_mask],
outputs=[score_loss, geo_dist_loss, geo_angle_loss])
else:
        # Convert distances and angle to vertex coordinates
vertex = layers.Lambda(lambda x: dist_to_box(*x))([predict_geo_dist, predict_geo_angle])
        # Pass image_meta through unchanged so it is returned alongside the predictions
        image_meta = layers.Lambda(lambda x: tf.identity(x))(input_image_meta)
        predict_score = layers.Lambda(lambda x: tf.nn.sigmoid(x))(predict_score)  # convert logits to scores
return Model(inputs=[input_image, input_image_meta],
outputs=[predict_score, vertex, image_meta])
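The three loss layers above wrap functions imported from east.layers.losses, whose bodies are not shown in this post. Here is a minimal sketch following the EAST paper's formulation; the signatures match the calls above, but the internals are my reconstruction:

import tensorflow as tf

def balanced_cross_entropy(y_true, y_pred_logits, mask):
    """Class-balanced cross-entropy on the score map; beta down-weights
    the dominant background class."""
    y_true = tf.cast(y_true, tf.float32)
    mask = tf.cast(mask, tf.float32)
    y_pred = tf.nn.sigmoid(tf.squeeze(y_pred_logits, axis=-1))
    beta = 1. - tf.reduce_sum(y_true * mask) / (tf.reduce_sum(mask) + 1e-6)
    loss = -(beta * y_true * tf.math.log(y_pred + 1e-6) +
             (1. - beta) * (1. - y_true) * tf.math.log(1. - y_pred + 1e-6))
    return tf.reduce_sum(loss * mask) / (tf.reduce_sum(mask) + 1e-6)

def iou_loss(d_true, d_pred, score_map, mask):
    """-log IoU between ground-truth and predicted rboxes, computed from the
    per-pixel distances to the 4 edges; averaged over positive pixels only."""
    d1t, d2t, d3t, d4t = tf.unstack(d_true, axis=-1)
    d1p, d2p, d3p, d4p = tf.unstack(d_pred, axis=-1)
    area_true = (d1t + d3t) * (d2t + d4t)
    area_pred = (d1p + d3p) * (d2p + d4p)
    inter = ((tf.minimum(d1t, d1p) + tf.minimum(d3t, d3p)) *
             (tf.minimum(d2t, d2p) + tf.minimum(d4t, d4p)))
    union = area_true + area_pred - inter
    loss = -tf.math.log((inter + 1.) / (union + 1.))
    pos = tf.cast(score_map, tf.float32) * tf.cast(mask, tf.float32)
    return tf.reduce_sum(loss * pos) / (tf.reduce_sum(pos) + 1e-6)

def angle_loss(a_true, a_pred, score_map, mask):
    """1 - cos(theta_pred - theta_true), over positive pixels only."""
    loss = 1. - tf.cos(tf.squeeze(a_pred, axis=-1) - tf.cast(a_true, tf.float32))
    pos = tf.cast(score_map, tf.float32) * tf.cast(mask, tf.float32)
    return tf.reduce_sum(loss * pos) / (tf.reduce_sum(pos) + 1e-6)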
def compile(keras_model, config, loss_names=None):
    """
    Compile the model: add the loss functions, L2 regularization and per-loss metrics.
    :param keras_model: model to compile
    :param config: configuration object
    :param loss_names: list of loss layer names
    :return:
    """
    loss_names = loss_names or []
    # Optimizer
optimizer = keras.optimizers.SGD(
lr=config.LEARNING_RATE, momentum=config.LEARNING_MOMENTUM,
clipnorm=config.GRADIENT_CLIP_NORM)
    # Add the loss functions; clear any existing ones first to avoid duplicates
keras_model._losses = []
keras_model._per_input_losses = {}
for name in loss_names:
layer = keras_model.get_layer(name)
if layer is None or layer.output in keras_model.losses:
continue
loss = (tf.reduce_mean(layer.output, keepdims=True)
* config.LOSS_WEIGHTS.get(name, 1.))
keras_model.add_loss(loss)
    # Add L2 regularization,
    # skipping the gamma and beta weights of the batch normalization layers
reg_losses = [
keras.regularizers.l2(config.WEIGHT_DECAY)(w) / tf.cast(tf.size(w), tf.float32)
for w in keras_model.trainable_weights
if 'gamma' not in w.name and 'beta' not in w.name]
keras_model.add_loss(tf.add_n(reg_losses))
    # Compile
keras_model.compile(
optimizer=optimizer,
        loss=[None] * len(keras_model.outputs))  # dummy losses; the real losses were added above via add_loss
    # Add a metric for each loss
for name in loss_names:
if name in keras_model.metrics_names:
continue
layer = keras_model.get_layer(name)
if layer is None:
continue
keras_model.metrics_names.append(name)
loss = (
tf.reduce_mean(layer.output, keepdims=True)
* config.LOSS_WEIGHTS.get(name, 1.))
keras_model.metrics_tensors.append(loss)
def add_metrics(keras_model, metric_name_list, metric_tensor_list):
"""
增加度量
:param keras_model: 模型
:param metric_name_list: 度量名稱(chēng)列表
:param metric_tensor_list: 度量張量列表
:return: 無(wú)
"""
for name, tensor in zip(metric_name_list, metric_tensor_list):
keras_model.metrics_names.append(name)
keras_model.metrics_tensors.append(tf.reduce_mean(tensor, keepdims=True))
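For completeness, here is a hypothetical setup that ties the pieces together; the Config fields mirror the attributes the code above references, and the values are purely illustrative:

# Hypothetical wiring; Config fields are inferred from the attributes
# referenced above, with illustrative values.
class Config:
    IMAGE_SHAPE = (640, 640, 3)
    LEARNING_RATE = 0.01
    LEARNING_MOMENTUM = 0.9
    GRADIENT_CLIP_NORM = 5.0
    WEIGHT_DECAY = 1e-4
    LOSS_WEIGHTS = {'score_loss': 1., 'dist_loss': 1., 'angle_loss': 10.}

config = Config()
train_model = east_net(config, stage='train')
compile(train_model, config,
        loss_names=['score_loss', 'dist_loss', 'angle_loss'])
# train_model.fit_generator(...) would go here
infer_model = east_net(config, stage='test')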
2. Text recognition
The EAST text detector handles locating and detecting the text; the next step is to recognize what the detected text says.
Turning the characters in an image into actual text requires OCR. The best-known and most widely used open-source OCR engine is Tesseract-OCR. Since everything to be recognized here is printed type and simple digits, Google's mature Tesseract-OCR toolkit is a good fit, especially now that it has reached version 4.0: compared with the traditional 3.x releases, the most prominent change in 4.0 is the LSTM-based recognition engine.
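As a quick illustration, here is a minimal sketch of recognizing one cropped text line with the 4.0 LSTM engine through the pytesseract wrapper (the wrapper and the file name are illustrative assumptions; any Tesseract binding works):

import pytesseract
from PIL import Image

# --oem 1 selects the LSTM engine; --psm 7 treats the image as a single text line
text = pytesseract.image_to_string(Image.open('crop.png'), config='--oem 1 --psm 7')
print(text.strip())

For digit-only meters the output can be constrained further with a character whitelist (-c tessedit_char_whitelist=0123456789.), though whitelist support for the LSTM engine only became reliable in later 4.x releases.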
3. End-to-end integration
Combining the EAST text detector and tesseract-ocr in one codebase yields an end-to-end solution: detect the text in an image, crop out each region, and recognize and output the text in a single pass.
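Here is a minimal sketch of the glue code, under two assumptions: the dense EAST outputs have already been reduced to a list of 4-point quadrilaterals by thresholding the score map and running locality-aware NMS (post-processing not shown here), and pytesseract is used for the OCR step:

import cv2
import numpy as np
import pytesseract

def rectify(image, quad):
    """Warp one quadrilateral (4x2 vertices, clockwise from top-left) into an
    axis-aligned crop so Tesseract sees a horizontal line of text."""
    src = np.asarray(quad, dtype=np.float32)
    w = int(max(np.linalg.norm(src[0] - src[1]), np.linalg.norm(src[3] - src[2])))
    h = int(max(np.linalg.norm(src[0] - src[3]), np.linalg.norm(src[1] - src[2])))
    dst = np.array([[0, 0], [w, 0], [w, h], [0, h]], dtype=np.float32)
    matrix = cv2.getPerspectiveTransform(src, dst)
    return cv2.warpPerspective(image, matrix, (w, h))

def detect_and_recognize(image, quads):
    """quads: quadrilaterals produced by the EAST post-processing step."""
    results = []
    for quad in quads:
        crop = rectify(image, quad)
        text = pytesseract.image_to_string(crop, config='--oem 1 --psm 7')
        results.append((quad, text.strip()))
    return results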