預(yù)測實現(xiàn)
在上一篇文章中實現(xiàn)了裝甲板id識別的網(wǎng)絡(luò)訓(xùn)練并保存為了ckpt文件
http://www.reibang.com/p/191337a9a819
雖然全連接的網(wǎng)絡(luò)精度也就那樣了礁蔗,但是還是練習(xí)一下用現(xiàn)有的網(wǎng)絡(luò)進行裝甲板id預(yù)測
- 復(fù)現(xiàn)網(wǎng)絡(luò)
#網(wǎng)絡(luò)搭建
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(500,activation='relu',kernel_regularizer=tf.keras.regularizers.l2()),
tf.keras.layers.Dense(128,activation='relu',kernel_regularizer=tf.keras.regularizers.l2()),
tf.keras.layers.Dense(50,activation='relu',kernel_regularizer=tf.keras.regularizers.l2()),
tf.keras.layers.Dense(8,activation='softmax',kernel_regularizer=tf.keras.regularizers.l2())
])
- 加載參數(shù)
#加載參數(shù)
ckpt_path = "./checkpoint/armor_id.ckpt"
if(os.path.exists(ckpt_path + ".index")):
print("--load modle--")
model.load_weights(ckpt_path)
else:
print('----------------------------------------------error')
- 輸入數(shù)據(jù)處理
#圖片讀取與處理
img = tf.io.read_file (test_img_path)
img_raw = tf.image.decode_bmp (img)
img_raw = tf.cast(img_raw,dtype=tf.float32)
x_predict = tf.convert_to_tensor(img_raw)
x_predict = tf.reshape(x_predict,[1,-1])
- 代碼整體實現(xiàn)
import tensorflow as tf
import os
if __name__ == '__main__':
test_img_path = './armor_dataset/8/8_47.bmp'
#網(wǎng)絡(luò)搭建
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(500,activation='relu',kernel_regularizer=tf.keras.regularizers.l2()),
tf.keras.layers.Dense(128,activation='relu',kernel_regularizer=tf.keras.regularizers.l2()),
tf.keras.layers.Dense(50,activation='relu',kernel_regularizer=tf.keras.regularizers.l2()),
tf.keras.layers.Dense(8,activation='softmax',kernel_regularizer=tf.keras.regularizers.l2())
])
#加載參數(shù)
ckpt_path = "./checkpoint/armor_id.ckpt"
if(os.path.exists(ckpt_path + ".index")):
print("--load modle--")
model.load_weights(ckpt_path)
else:
print('----------------------------------------------error')
#圖片讀取與處理
img = tf.io.read_file (test_img_path)
img_raw = tf.image.decode_bmp (img)
img_raw = tf.cast(img_raw,dtype=tf.float32)
x_predict = tf.convert_to_tensor(img_raw)
x_predict = tf.reshape(x_predict,[1,-1])
#預(yù)測結(jié)果
result = model.predict(x_predict)
pred = tf.argmax(result,axis=1) #獲取概率最大數(shù)值的下標
pred = pred + 1
print("預(yù)測id為:")
tf.print(pred)
遇到的坑
- 實際未讀入數(shù)據(jù)
現(xiàn)象是每次輸出結(jié)果隨機變化 - 使用tfrecord解碼的數(shù)據(jù)和使用原始數(shù)據(jù)解碼的數(shù)據(jù)不一致
應(yīng)當檢查編碼解碼過程中的類型轉(zhuǎn)換
http://www.reibang.com/p/51659ec687f8
測試結(jié)果
測試圖片
51.png
5177.png
847.png
還進行了其他數(shù)字的測試
測試圖片基本都實現(xiàn)了正確的預(yù)測