最近看到一個巨牛的人工智能教程,分享一下給大家豆同。教程不僅是零基礎(chǔ)番刊,通俗易懂,而且非常風(fēng)趣幽默影锈,像看小說一樣芹务!覺得太牛了,所以分享給大家鸭廷。平時碎片時間可以當(dāng)小說看枣抱,【點(diǎn)這里可以去膜拜一下大神的“小說”】。
1. 下載網(wǎng)絡(luò)結(jié)構(gòu)及模型
1.1 下載MobileNet V1定義網(wǎng)絡(luò)結(jié)構(gòu)的文件
MobileNet V1的網(wǎng)絡(luò)結(jié)構(gòu)可以直接從官方Github庫中下載定義網(wǎng)絡(luò)結(jié)構(gòu)的文件辆床,地址為:https://raw.githubusercontent.com/tensorflow/models/master/research/slim/nets/mobilenet_v1.py
1.2 下載MobileNet V1預(yù)訓(xùn)練模型
MobileNet V1預(yù)訓(xùn)練的模型文在如下地址中下載:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md
打開以上網(wǎng)址佳晶,可以看到MobileNet V1官方預(yù)訓(xùn)練的模型,官方提供了不同輸入尺寸和不同網(wǎng)絡(luò)中通道數(shù)的多個模型讼载,并且提供了每個模型對應(yīng)的精度轿秧。可以根據(jù)實(shí)際的需要下載對應(yīng)的模型咨堤,如下圖所示菇篡。
這里以選擇MobileNet_v1_1.0_192為例,表示網(wǎng)絡(luò)中的所有卷積后的通道數(shù)為標(biāo)準(zhǔn)通道數(shù)(即1.0倍)一喘,輸入圖像尺寸為192X192驱还。
2. 構(gòu)建網(wǎng)絡(luò)結(jié)構(gòu)及加載模型參數(shù)
2.1 構(gòu)建網(wǎng)絡(luò)結(jié)構(gòu)
在1.1小節(jié)中下載mobilenet_v1.py文件后,使用其中的mobilenet_v1函數(shù)構(gòu)建網(wǎng)絡(luò)結(jié)構(gòu)靜態(tài)圖凸克,如下代碼所示议蟆。
import tensorflow as tf
from mobilenet_v1 import mobilenet_v1,mobilenet_v1_arg_scope
slim = tf.contrib.slim
def build_model(inputs):
with slim.arg_scope(mobilenet_v1_arg_scope(is_training=False)):
logits, end_points = mobilenet_v1(inputs, is_training=False, depth_multiplier=1.0, num_classes=1001)
scores = end_points['Predictions']
print(scores)
#取概率最大的3個類別及其對應(yīng)概率
output = tf.nn.top_k(scores, k=3, sorted=True)
#indices為類別索引,values為概率值
return output.indices,output.values
上面代碼中萎战,使用函數(shù)tf.nn.top_k取概率最大的3個類別機(jī)器對應(yīng)概率咐容。
2.2 加載模型參數(shù)
CKPT = 'mobilenet_v1_1.0_192.ckpt'
def load_model(sess):
loader = tf.train.Saver()
loader.restore(sess,CKPT)
inputs=tf.placeholder(dtype=tf.float32,shape=(1,192,192,3))
classes_tf,scores_tf = build_model(inputs)
with tf.Session() as sess:
load_model(sess)
先定義placeholder輸入inputs,再通過函數(shù)build_model完成靜態(tài)圖的定義撞鹉。接下來傳入tf.Session對象到load_model函數(shù)中完成模型加載疟丙。
3. 模型測試
3.1 加載Label
網(wǎng)絡(luò)輸出結(jié)果為類別的索引值颖侄,需要將索引值轉(zhuǎn)為對應(yīng)的類別字符串。先從官網(wǎng)下載label數(shù)據(jù)享郊,需要注意的是MobileNet V1使用的是ILSVRC-2012-CLS數(shù)據(jù)览祖,因此需要下載對應(yīng)的Label信息(本文后面附件中會提供)。解析Label數(shù)據(jù)代碼如下炊琉。
def load_label():
label=['其他']
with open('label.txt','r',encoding='utf-8') as r:
lines = r.readlines()
for l in lines:
l = l.strip()
arr = l.split(',')
label.append(arr[1])
return label
3.2 測試結(jié)果
使用如下圖片進(jìn)行測試展蒂。
執(zhí)行inference.py后,控制臺輸出結(jié)果如下所示苔咪。
識別 test_images/test1.png 結(jié)果如下:
No. 0 類別: 軍用飛機(jī) 概率: 0.9363691
No. 1 類別: 飛機(jī)翅膀 概率: 0.032617383
No. 2 類別: 炮彈 概率: 0.01853972
識別 test_images/test2.png 結(jié)果如下:
No. 0 類別: 小兒床 概率: 0.9455737
No. 1 類別: 搖籃 概率: 0.044925883
No. 2 類別: 板架 概率: 0.007288801
4 完整代碼
inference.py完整的代碼如下所示锰悼。
import tensorflow as tf
from mobilenet_v1 import mobilenet_v1,mobilenet_v1_arg_scope
import cv2
import os
import numpy as np
slim = tf.contrib.slim
CKPT = 'mobilenet_v1_1.0_192.ckpt'
dir_path = 'test_images'
def build_model(inputs):
with slim.arg_scope(mobilenet_v1_arg_scope(is_training=False)):
logits, end_points = mobilenet_v1(inputs, is_training=False, depth_multiplier=1.0, num_classes=1001)
scores = end_points['Predictions']
print(scores)
#取概率最大的5個類別及其對應(yīng)概率
output = tf.nn.top_k(scores, k=3, sorted=True)
#indices為類別索引,values為概率值
return output.indices,output.values
def load_model(sess):
loader = tf.train.Saver()
loader.restore(sess,CKPT)
def get_data(path_list,idx):
img_path = images_path[idx]
img = cv2.imread(img_path)
img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
img = cv2.resize(img,(192,192))
img = np.expand_dims(img,axis=0)
img = (img/255.0-0.5)*2.0
return img_path,img
def load_label():
label=['其他']
with open('label.txt','r',encoding='utf-8') as r:
lines = r.readlines()
for l in lines:
l = l.strip()
arr = l.split(',')
label.append(arr[1])
return label
inputs=tf.placeholder(dtype=tf.float32,shape=(1,192,192,3))
classes_tf,scores_tf = build_model(inputs)
images_path =[dir_path+'/'+n for n in os.listdir(dir_path)]
label=load_label()
with tf.Session() as sess:
load_model(sess)
for i in range(len(images_path)):
path,img = get_data(images_path,i)
classes,scores = sess.run([classes_tf,scores_tf],feed_dict={inputs:img})
print('\n識別',path,'結(jié)果如下:')
for j in range(3):#top 3
idx = classes[0][j]
score=scores[0][j]
print('\tNo.',j,'類別:',label[idx],'概率:',score)