前言
TensorFlow 是谷歌開源的深度學(xué)習(xí)工具包,它將深度學(xué)習(xí)復(fù)雜的計算過程抽象成了數(shù)據(jù)流圖(Data Flow Graph)谢床,并提供簡介靈活的高級抽象接口兄一,讓小白用戶通過簡單的學(xué)習(xí)就可以使用「高大上」的深度學(xué)習(xí)了。當(dāng)谷歌被問到為什么要開源 TensorFlow 時萤悴,他們的回答是:「我們相信機器學(xué)習(xí)是未來創(chuàng)新產(chǎn)品和技術(shù)的關(guān)鍵因素」瘾腰。其實目前我們生活的很多方面已經(jīng)被機器學(xué)習(xí)深切的影響著了,從谷歌搜索到淘寶購物覆履,以及在剛剛結(jié)束的老羅手機發(fā)布會上蹋盆,引起臺下歡呼不斷的語音輸入以及 「BigBang」,背后的核心技術(shù)都是機器學(xué)習(xí)硝全。因此栖雾,我覺得廣大程序員有必要學(xué)習(xí)下機器學(xué)習(xí),以免被很快就會到來的「未來」所拋棄伟众。
安裝
官方的安裝文檔是最詳細(xì)的析藕,以我本人的電腦(MacBook Pro,python 2.7)為例凳厢,安裝過程如下:
// 指定 TensorFlow 地址
export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.11.0rc1-py2-none-any.whl
// 使用 pip 安裝(如果沒有 pip账胧,得先安裝 pip)
sudo pip install --upgrade $TF_BINARY_URL
安裝好后,可以進入 python 交互界面先紫,執(zhí)行以下代碼確認(rèn)是否安裝成功:
$ python
>>> import tensorflow as tf
>>> hello = tf.constant('Hello, TensorFlow!')
>>> sess = tf.Session()
>>> print(sess.run(hello))
Hello, TensorFlow!
>>> a = tf.constant(10)
>>> b = tf.constant(32)
>>> print(sess.run(a + b))
42
簡介
整體工作方式
TensorFlow 最基本的一次計算過程是這樣的:接受 n 個固定格式的數(shù)據(jù)輸入治泥,通過給定的高級函數(shù),轉(zhuǎn)化為 n 個 Tensor 格式的輸出遮精。當(dāng)然一次機器學(xué)習(xí)的過程會有很多次這樣的計算居夹,一次計算的輸出可能是下一次計算的(部分或全部)輸入,TensorFlow 將這一系列的計算過程抽象為了一張數(shù)據(jù)流圖(Data Flow Graph)本冲,如圖:
上圖中從數(shù)據(jù) Input 開始准脂,沿著有向圖進行計算,圖中每個節(jié)點都是一次計算檬洞,稱為 op(option)狸膏,TensorFlow 中數(shù)據(jù)以 Tensor 為格式,輸入一個 Tensor 添怔,經(jīng)過一次 op 后輸出另一個 Tensor湾戳,然后根據(jù)數(shù)據(jù)流圖進入下一個 op 作為輸入闷板,因此,整個計算過程其實是一個 Tensor 數(shù)據(jù)的流動過程院塞,所以谷歌將這個系統(tǒng)形象的叫做 TensorFlow。
有了數(shù)據(jù)流圖后下一個問題是如何在各種設(shè)備上很好的運行性昭,TensorFlow 通過一個會話(Session)來控制整個數(shù)據(jù)流圖的執(zhí)行拦止。TensorFlow 一個很大的優(yōu)點是將復(fù)雜的運算(如矩陣運算,softmax)封裝成了高級函數(shù)糜颠,用戶只要使用就好了汹族,在內(nèi)部,TensorFlow 將這些函數(shù)轉(zhuǎn)化成可以高效在 CPU 或 GPU 執(zhí)行的機器碼其兴。Session 的主要作用是將這張數(shù)據(jù)流圖合理的切分(盡量減少 Session 與 CPU 或 GPU 之間的交互顶瞒,因為很慢),按照一定的順序提交給 CPU 或者 GPU元旬,然后(可能)還進行一些容錯的機制榴徐,總之 Session 就是負(fù)責(zé)高效地讓數(shù)據(jù)流圖被 CPU 或 GPU 執(zhí)行完成的。
如果讀者對 spark 熟悉匀归,看完以上介紹后會不會覺得其實 TensorFlow 的整個工作流程其實跟 Spark 有很多類似之處(我看完文檔之后是這樣認(rèn)為的)坑资。 TensorFlow 的數(shù)據(jù)流圖對應(yīng)于 Spark 的 DAG,TensorFlow 的 Tensor 對應(yīng)于 Spark 的 RDD穆端, TensorFlow 的 Session 對應(yīng)于 Spark 的 SparkContext袱贮。待后面深入學(xué)習(xí)后再分析分析他們的異同。
基本概念
其實很多概念上面已經(jīng)提到体啰,這里統(tǒng)一介紹下攒巍,不過官網(wǎng)的文檔是最詳細(xì)的。
- 數(shù)據(jù)流圖:用來邏輯上描述一次機器學(xué)習(xí)計算的過程荒勇。
- Session:負(fù)責(zé)管理協(xié)調(diào)整個數(shù)據(jù)流圖的計算過程柒莉。
- op:數(shù)據(jù)流圖中的一個節(jié)點,也就是一次基本的操作過程枕屉。
- Tensor:所有 TensorFlow 中計算的數(shù)據(jù)的格式常柄,是一個 n 維的數(shù)組,如 t = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] 搀擂。
- Variable:這個在上面沒有提到西潘。之前我們說一個 Tensor 經(jīng)過一次 op 后會轉(zhuǎn)化為另一個 Tensor,完成之后上一步的輸入 Tensor 就會被回收掉哨颂。如果有些數(shù)據(jù)(如模型)我們需要一直保存喷市,每次迭代計算只是改變其值,這時我們就需要 Variable威恼,Variable 本質(zhì)上也是一個 Tensor品姓,只不過他是不會被回收寝并,常駐的 Tensor。
基本使用
關(guān)于 TensorFlow 的基本使用腹备,其實還是建議看官網(wǎng)衬潦,我這邊列一個小的 demo 大致看下一個最基本的 TensorFlow 是怎么寫的。
// 導(dǎo)入 tensorflow
import tensorflow as tf
// 上面說了數(shù)據(jù)是從輸入到 op 再到輸出植酥,用 tf.constant() 生成的是一個源 op镀岛,
// 是一個不需要輸入(輸入已經(jīng)定義在代碼里了),只有輸出的 op友驮。
// matrix1漂羊、matrix2 都是常量 op
matrix1 = tf.constant([[3., 3.]])
matrix2 = tf.constant([[2.],[2.]])
// 創(chuàng)建一個矩陣乘法 matmul op , 把 'matrix1' 和 'matrix2' 作為輸入.
// 返回值 'product' 代表矩陣乘法的結(jié)果。
product = tf.matmul(matrix1, matrix2)
// 創(chuàng)建 session
sess = tf.Session()
// 使用 session 運算整個數(shù)據(jù)流圖卸留。這里我們可以看到走越,
// 定義整個計算過程的時候,我們是從數(shù)據(jù)輸入的方向上往下定義的耻瑟,
// 而給 session 運行的時候旨指,我們只把最后我們想要結(jié)果的那個方法
// 給了 session,說明 session 會根據(jù)最終的那步喳整,逆向去追溯計算
// 過程淤毛,從而構(gòu)建整個數(shù)據(jù)流圖(猜的)。
result = sess.run(product)
//輸出 [[ 12.]]
print result
//運行結(jié)束算柳,關(guān)閉 session
sess.close()
看到這里低淡,其實會發(fā)現(xiàn) TensorFlow 并不僅僅是一個訓(xùn)練深度神經(jīng)網(wǎng)絡(luò)的工具,它更像是一個計算框架瞬项,其實可以在其基礎(chǔ)上實現(xiàn)各種算法蔗蹋,有朝一日,如果越來越多的開發(fā)者參與其中囱淋,構(gòu)建出更豐富的機器學(xué)習(xí)庫猪杭,替代 spark 也是有可能的(雖然我覺得可能性不高)。
一個簡單的機器學(xué)習(xí)實例
官方文檔 給出了一個利用 MNIST 數(shù)據(jù)集訓(xùn)練圖像識別模型的入門例子妥衣,我這里做一些簡單的分析皂吮。
MNIST 數(shù)據(jù)集包含了一個訓(xùn)練集和一個測試集,集合中包含一些 28 像素 X 28 像素的圖片(這些圖片是手寫的 0-9 的數(shù)字)税手,以及每張圖片的真實數(shù)字蜂筹,我們的目的是通過訓(xùn)練集訓(xùn)練出一個機器學(xué)習(xí)模型,用以預(yù)測任意的手寫輸入圖片所對應(yīng)的數(shù)字芦倒。
文檔中給的入門示例是一個邏輯回歸模型艺挪。輸入是28 X 28 像素的圖片,為了方便兵扬,我們將其轉(zhuǎn)化為一個長度為 28 * 28=784 的數(shù)組麻裳;輸出是用 softmax 回歸后該圖片屬于 0-9 各個數(shù)字的概率數(shù)組口蝠。整個模型通過隨機梯度下降的方法進行訓(xùn)練,詳細(xì)的文檔里都有津坑,我就其代碼簡單分析一下:
// 導(dǎo)入 tensorflow
import tensorflow as tf
// 導(dǎo)入處理 MNIST 數(shù)據(jù)集的工具類
import input_data
// 加載 MNIST 數(shù)據(jù)集妙蔗,獲得一個封裝好的對象 mnist
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
// 設(shè)置輸入(即手寫圖像),placeholder 表示 x是一個占位符疆瑰,
// shape 的 None 表示輸入數(shù)據(jù)的第一維可以任意大小
x = tf.placeholder(tf.float32, shape=[None, 784])
// 設(shè)置針對輸入圖像灭必,其期望的輸出
y_ = tf.placeholder(tf.float32, shape=[None, 10])
// 定義 Variable,用于存儲隱藏層的權(quán)重乃摹,W 是一個二維權(quán)重矩陣,
// W[i][j] 表示第 i 個像素屬于第 j 個數(shù)字的概率跟衅;b 是一維數(shù)組孵睬,b[i]
// 表示第 i 個數(shù)字的偏移量。因此伶跷,要求輸入 x 屬于各個數(shù)字的概率
// 公式為:W * x + b
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
/** 定義好需要的參數(shù)變量后掰读,就可以設(shè)置訓(xùn)練過程了,訓(xùn)練過程主要分為 3 步:
1. 定義隱藏層的輸入輸出過程
2. 定義損失函數(shù)
3. 選擇訓(xùn)練方法開始訓(xùn)練 **/
// 1.定義隱藏層的輸入輸出過程 :
// 之前我們說用 softmax 回歸模型來做隱藏層叭莫,TensorFlow 已經(jīng)實現(xiàn)了
// softmax 的具體方法蹈集,所以我們只要一行代碼就能表示整個前饋的過程
y = tf.nn.softmax(tf.matmul(x, W) + b)
// 2.定義損失函數(shù):
// 我們使用交叉熵來衡量結(jié)果的好壞
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
// 3.選擇訓(xùn)練方法開始訓(xùn)練:
// 由于我們已經(jīng)知道了損失函數(shù),我們的訓(xùn)練目的是讓損失函數(shù)最小雇初,
// 這里我們使用梯度下降的方法求最小值
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
// 定義好訓(xùn)練過程后拢肆,就可以開始真正的訓(xùn)練過程了
// 初始化 session
sess = tf.Session()
//加載所有 variable
init = tf.initialize_all_variables()
sess.run(init)
// 使用隨機梯度下降的方法,分批次多次訓(xùn)練
for i in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100)
//這里隨機獲取的 batch_xs, batch_ys 用來填充之前定義的占位符 x, y_
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
我們可以看到靖诗,只要告訴 TensorFlow 待訓(xùn)練參數(shù)郭怪、損失函數(shù)、具體的訓(xùn)練方法刊橘,區(qū)區(qū)三行代碼鄙才,它就能自動地進行訓(xùn)練出一個圖像識別模型。當(dāng)然促绵,這只是最簡單的一個邏輯回歸模型攒庵,要想獲得好效果,需要用到 卷積神經(jīng)網(wǎng)絡(luò)(CNN)败晴,等后面深入學(xué)習(xí)后再跟讀者分享浓冒。
總結(jié)
本文簡單介紹了 TensorFlow 的一些基本的概念和工作方式,后面有精力的話再深入的學(xué)習(xí)下尖坤。另外裆蒸,由于自己對機器學(xué)習(xí)不是很熟悉,對 TensorFlow 也是剛接觸糖驴,所以文中可能會有比較多低級錯誤僚祷,望讀者看到后指出佛致。