本文改編自TensorFLow官方教程中文版遍坟,力求更加簡潔、清晰晴股。
一政鼠、介紹
TensorFlow是當(dāng)前最流行的機(jī)器學(xué)習(xí)框架,有了它队魏,開發(fā)人工智能程序就像Java編程一樣簡單公般。今天,就讓我們從手寫體數(shù)字識(shí)別入手胡桨,看看如何用機(jī)器學(xué)習(xí)的方法解決這個(gè)問題官帘。
二、編程環(huán)境
Python2.7+TensorFlow0.5.0下測試通過昧谊,Python3.5下未測試刽虹。請(qǐng)參考《TensorFLow下載與安裝》配置環(huán)境。
三呢诬、思路
沒有接觸過圖像處理的人可能會(huì)很納悶涌哲,從一張圖片識(shí)別出里面的內(nèi)容似乎是件相當(dāng)神奇的事情。其實(shí)尚镰,當(dāng)你把圖片當(dāng)成一枚枚像素來看的話阀圾,就沒那么神秘了。下圖為手寫體數(shù)字1的圖片狗唉,它在計(jì)算機(jī)中的存儲(chǔ)其實(shí)是一個(gè)二維矩陣初烘,每個(gè)元素都是0~1之間的數(shù)字,0代表白色分俯,1代表黑色肾筐,小數(shù)代表某種程度的灰色。
現(xiàn)在缸剪,對(duì)于MNIST數(shù)據(jù)集中的圖片來說吗铐,我們只要把它當(dāng)成長度為784的向量就可以了(忽略它的二維結(jié)構(gòu),28×28=784)杏节。我們的任務(wù)就是讓這個(gè)向量經(jīng)過一個(gè)函數(shù)后輸出一個(gè)類別唬渗,吶典阵,就是下邊這個(gè)函數(shù),稱為Softmax分類器谣妻。
這個(gè)式子里的圖片向量的長度只有3萄喳,用x表示。乘上一個(gè)系數(shù)矩陣W蹋半,再加上一個(gè)列向量b他巨,然后輸入softmax函數(shù),輸出就是分類結(jié)果y减江。W是一個(gè)權(quán)重矩陣染突,W的每一行與整個(gè)圖片像素相乘的結(jié)果是一個(gè)分?jǐn)?shù)score,分?jǐn)?shù)越高表示圖片越接近該行代表的類別辈灼。因此份企,W x + b 的結(jié)果其實(shí)是一個(gè)列向量,每一行代表圖片屬于該類的評(píng)分巡莹。熟悉圖像分類的同學(xué)應(yīng)該了解司志,通常分類的結(jié)果并非評(píng)分,而是概率降宅,表示有多大的概率屬于此類別骂远。因此,Softmax函數(shù)的作用就是把評(píng)分轉(zhuǎn)換成概率腰根,并使總的概率為1激才。
有了這個(gè)模型,如何訓(xùn)練它呢额嘿?
對(duì)于機(jī)器學(xué)習(xí)算法來說瘸恼,訓(xùn)練就是不斷調(diào)整模型參數(shù)使誤差達(dá)到最小的過程。這里的模型參數(shù)就是W和b册养。接下來我們需要定義誤差东帅。誤差當(dāng)然是把預(yù)測的結(jié)果y和正確結(jié)果相比較得到的,但是由于正確結(jié)果是one_hot向量(即只有一個(gè)元素是1捕儒,其它元素都是0)冰啃,而預(yù)測結(jié)果是個(gè)概率向量,用什么方法比較其實(shí)是個(gè)需要深入考慮的事情刘莹。事實(shí)上,我們使用的是交叉熵?fù)p失(cross-entropy loss)焚刚,為什么用這個(gè)点弯,其實(shí)我現(xiàn)在也不太清楚,所以姑且先用著吧矿咕,以后見得多了自然就明白了抢肛。
好了狼钮,到這里思路大體上就講完了,還有不清楚的地方讓我們看看代碼就能理解了捡絮。
四熬芜、TensorFlow實(shí)現(xiàn)
說實(shí)話,這個(gè)代碼比想象中還要簡練福稳,只有33行涎拉,所以我把它直接貼出來。
# coding=utf-8
import tensorflow as tf
import input_data
# 下載MNIST數(shù)據(jù)集到'MNIST_data'文件夾并解壓
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# 設(shè)置權(quán)重weights和偏置biases作為優(yōu)化變量的圆,初始值設(shè)為0
weights = tf.Variable(tf.zeros([784, 10]))
biases = tf.Variable(tf.zeros([10]))
# 構(gòu)建模型
x = tf.placeholder("float", [None, 784])
y = tf.nn.softmax(tf.matmul(x, weights) + biases) # 模型的預(yù)測值
y_real = tf.placeholder("float", [None, 10]) # 真實(shí)值
cross_entropy = -tf.reduce_sum(y_real * tf.log(y)) # 預(yù)測值與真實(shí)值的交叉熵
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) # 使用梯度下降優(yōu)化器最小化交叉熵
# 開始訓(xùn)練
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100) # 每次隨機(jī)選取100個(gè)數(shù)據(jù)進(jìn)行訓(xùn)練鼓拧,即所謂的“隨機(jī)梯度下降(Stochastic Gradient Descent,SGD)”
sess.run(train_step, feed_dict={x: batch_xs, y_real:batch_ys}) # 正式執(zhí)行train_step越妈,用feed_dict的數(shù)據(jù)取代placeholder
if i % 100 == 0:
# 每訓(xùn)練100次后評(píng)估模型
correct_prediction = tf.equal(tf.argmax(y, 1), tf.arg_max(y_real, 1)) # 比較預(yù)測值和真實(shí)值是否一致
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) # 統(tǒng)計(jì)預(yù)測正確的個(gè)數(shù)季俩,取均值得到準(zhǔn)確率
print sess.run(accuracy, feed_dict={x: mnist.test.images, y_real: mnist.test.labels})
這里用到了官方給的一個(gè)代碼文件input_data
,我已經(jīng)放到工程里了梅掠。導(dǎo)入input_data
酌住,就可以用它來讀取MNIST數(shù)據(jù)集,非常方便阎抒。
整體來說酪我,使用TensorFLow編程主要分為兩個(gè)階段,第一個(gè)階段是構(gòu)建模型挠蛉,把網(wǎng)絡(luò)模型用代碼搭建起來祭示。TensorFlow的本質(zhì)是數(shù)據(jù)流圖,因此這一階段其實(shí)是在規(guī)定數(shù)據(jù)的流動(dòng)方向谴古。第二個(gè)階段是開始訓(xùn)練质涛,把數(shù)據(jù)輸入到模型中,并通過梯度下降等方法優(yōu)化變量的值掰担。
首先汇陆,我們需要把權(quán)重weights和偏置biases設(shè)置成優(yōu)化變量,只有優(yōu)化變量才可以在后面被Optimizer優(yōu)化带饱。并且需要為它們賦初值毡代,這里將weights設(shè)為784×10的zero矩陣,把biases設(shè)為1×10的zero矩陣勺疼。
然后構(gòu)建模型教寂。模型的輸入一般設(shè)置為placeholder,譯為占位符执庐。在訓(xùn)練的過程中只有placeholder可以允許數(shù)據(jù)輸入酪耕。第一維的長度為None表示允許輸入任意長度,也就是說輸入可以是任意張圖像轨淌。
使用tf.log
計(jì)算y
中每個(gè)元素的對(duì)數(shù)迂烁,并逐個(gè)與y_real
相乘看尼,再求和并取反,就得到了交叉熵盟步。使用梯度下降優(yōu)化器最小化交叉熵作為訓(xùn)練步驟train_step
藏斩。
接下來開始訓(xùn)練。首先要調(diào)用tf.initialize_all_variables()
方法初始化所有變量却盘。再創(chuàng)建一個(gè)tf.Session
對(duì)象來控制整個(gè)訓(xùn)練流程狰域。循環(huán)訓(xùn)練1000次,每次從訓(xùn)練集中隨機(jī)取100個(gè)數(shù)據(jù)進(jìn)行訓(xùn)練谷炸。
在訓(xùn)練的過程中北专,每隔100次對(duì)模型進(jìn)行一次評(píng)估。評(píng)估使用測試集數(shù)據(jù)旬陡,統(tǒng)計(jì)正確預(yù)測的個(gè)數(shù)的百分比并輸出拓颓。結(jié)果如下:
$ /usr/bin/python2.7 /home/wjg/projects/MNISTRecognition/main.py
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
0.4075
0.894
0.8989
0.9012
0.904
0.9105
0.9086
0.9137
0.9105
0.9174
Process finished with exit code 0
可見預(yù)測準(zhǔn)確率逐漸上升,最后達(dá)到91%描孟。
五驶睦、總結(jié)
這是我第一次使用TensorFlow,它給我的感覺是非常方便匿醒,很貼合程序員的開發(fā)習(xí)慣场航。相比之下,之前用Caffe的時(shí)候就總是摸不著頭腦廉羔。當(dāng)然也可能是因?yàn)門ensorFlow的官方文檔更友好的緣故溉痢。
本文在很多地方都語焉不詳,因?yàn)樽髡咚接邢薇锼嘘P(guān)深?yuàn)W的數(shù)學(xué)原理都一帶而過孩饼。所以如果想要深入了解,還是推薦大家看官方教程竹挡。文末的參考資料一欄列出了一些有幫助的文章和視頻镀娶。
最后,可以從我的GitHub上下載完整代碼:https://github.com/jingedawang/MNISTRecognition
另外揪罕,熟悉多維矩陣操作(NumPy中的切片和廣播)可以更好的地理解代碼梯码,建議閱讀參考資料最后一條:P
六、參考資料
MNIST機(jī)器學(xué)習(xí)入門 TensorFlow中文社區(qū)
莫煩 Tensorflow 16 Classification 分類學(xué)習(xí) 莫煩
Classification 分類學(xué)習(xí) 莫煩
Softmax 函數(shù)的特點(diǎn)和作用是什么好啰? 知乎
CS231n課程筆記翻譯:線性分類筆記(下) 杜客譯
CS231n課程筆記翻譯:Python Numpy教程 杜客譯