Caffe2 玩玩回歸(Toy Regression)[5]

前言

這一節(jié)將講述如何使用Caffe2的特征進(jìn)行簡單的線性回歸學(xué)習(xí)。主要分為以下幾步:
- 生成隨機(jī)數(shù)據(jù)作為模型的輸入
- 用這些數(shù)據(jù)創(chuàng)建網(wǎng)絡(luò)
- 自動訓(xùn)練模型
- 查看梯度遞減的結(jié)果和學(xué)習(xí)過程中網(wǎng)絡(luò)參數(shù)的變化
ipython notebook教程請看這里
譯者注:如果圖片看不清树肃,可以保存到本地查看甸赃。

這是一個快速的例子,展示如何使用前面的基礎(chǔ)教程進(jìn)行快速的嘗試用CNN進(jìn)行回歸趟据。我們要解決的問題非常簡單券犁,輸入是二維的x,輸出是一維的y,權(quán)重w=[2.0,1.5]汹碱,偏置b=0.5粘衬。所以生成ground truth的等式是y=wx+b

在這個教程中,我們將會使用Caffe2的op生成訓(xùn)練數(shù)據(jù)稚新。注意勘伺,這和你日常訓(xùn)練工作不同:在真實的訓(xùn)練中,訓(xùn)練數(shù)據(jù)一般從外部源載入枷莉,比如Caffe的DB數(shù)據(jù)庫娇昙,或者Hive表。我們將會在MNIST的例程中講到笤妙。

這個例程中冒掌,每一個Caffe2 的op將會寫得非常詳細(xì),所以會顯得太多繁雜蹲盘。但是在MNIST例程中股毫,我們將使用CNN模型助手來構(gòu)建CNN模型。

from caffe2.python import core, cnn, net_drawer, workspace, visualize
import numpy as np
from IPython import display
from matplotlib import pyplot

聲明計算圖

這里召衔,我們聲明兩個圖:一個用于初始化計算中將會用到的變量參數(shù)和常量铃诬,另外一個作為主圖將會用于跑起梯度下降,也就是訓(xùn)練苍凛。(譯者注:不明白為啥叫做計算圖(computation graphs)趣席,其實看代碼和前一個教程的一樣,就是創(chuàng)建兩個net醇蝴,一個用于初始化參數(shù)宣肚,一個用于訓(xùn)練。)

首先悠栓,初始化網(wǎng)絡(luò):網(wǎng)絡(luò)的名字不重要霉涨。我們基本上把初始化代碼放在一個net中,這樣惭适,我們就可以調(diào)用RunNetOnce()函數(shù)來執(zhí)行笙瑟。我們分離init_net的原因是,這些操作在整個訓(xùn)練的過程中只需要執(zhí)行一次癞志。

init_net = core.Net("init")
# ground truth 參數(shù).
W_gt = init_net.GivenTensorFill( [], "W_gt", shape=[1, 2], values=[2.0, 1.5])
B_gt = init_net.GivenTensorFill([], "B_gt", shape=[1], values=[0.5])
# Constant value ONE is used in weighted sum when updating parameters.
ONE = init_net.ConstantFill([], "ONE", shape=[1], value=1.)
# ITER是迭代的次數(shù).
ITER = init_net.ConstantFill([], "ITER", shape=[1], value=0, dtype=core.DataType.INT32)

# 隨機(jī)初始化權(quán)重往枷,范圍在[-1,1],初始化偏置為0
W = init_net.UniformFill([], "W", shape=[1, 2], min=-1., max=1.)
B = init_net.ConstantFill([], "B", shape=[1], value=0.0)
print('Created init net.')

上面代碼創(chuàng)建并初始化了init_net網(wǎng)絡(luò)今阳。主訓(xùn)練網(wǎng)絡(luò)如下师溅,我們展示了創(chuàng)建的的每一步。
- 前向傳播產(chǎn)生loss
- 通過自動微分進(jìn)行后向傳播
- 使用標(biāo)準(zhǔn)的SGD進(jìn)行參數(shù)更新

train_net = core.Net("train")
# First, 生成隨機(jī)的樣本X和創(chuàng)建ground truth.
X = train_net.GaussianFill([], "X", shape=[64, 2], mean=0.0, std=1.0, run_once=0)
Y_gt = X.FC([W_gt, B_gt], "Y_gt")
# 往ground truth添加高斯噪聲
noise = train_net.GaussianFill([], "noise", shape=[64, 1], mean=0.0, std=1.0, run_once=0)
Y_noise = Y_gt.Add(noise, "Y_noise")
#注意到不需要講梯度信息傳播到 Y_noise層,
#所以使用StopGradient 函數(shù)告訴偏微分算法不需要做這一步
Y_noise = Y_noise.StopGradient([], "Y_noise")

# 線性回歸預(yù)測
Y_pred = X.FC([W, B], "Y_pred")

# 使用歐拉損失并對batch進(jìn)行平均
dist = train_net.SquaredL2Distance([Y_noise, Y_pred], "dist")
loss = dist.AveragedLoss([], ["loss"])

現(xiàn)在讓我們看看網(wǎng)絡(luò)是什么樣子的盾舌。從下面的圖可以看到,主要包含四部分蘸鲸。
- 隨機(jī)生成X
- 使用W_gt,B_gtFC操作生成grond truth Y_gt
- 使用當(dāng)前的參數(shù)W和B進(jìn)行預(yù)測
- 比較輸出和計算損失

graph = net_drawer.GetPydotGraph(train_net.Proto().op, "train", rankdir="LR")
display.Image(graph.create_png(), width=800)

現(xiàn)在妖谴,和其他框架相似,Caffe2允許我們自動地生成梯度操作,讓我們試一下膝舅,并看看計算圖有什么變化嗡载。

# Get gradients for all the computations above.
gradient_map = train_net.AddGradientOperators([loss])
graph = net_drawer.GetPydotGraph(train_net.Proto().op, "train", rankdir="LR")
display.Image(graph.create_png(), width=800)

一旦我們獲得參數(shù)的梯度,我們就可以將進(jìn)行SGD操作:獲得當(dāng)前step的學(xué)習(xí)率仍稀,更參數(shù)洼滚。在這個例子中,我們沒有做任何復(fù)雜的操作技潘,只是簡單的SGD遥巴。

# 迭代數(shù)增加1.
train_net.Iter(ITER, ITER)
# 根據(jù)迭代數(shù)計算學(xué)習(xí)率.
LR = train_net.LearningRate(ITER, "LR", base_lr=-0.1, policy="step", stepsize=20, gamma=0.9)
# 權(quán)重求和
train_net.WeightedSum([W, ONE, gradient_map[W], LR], W)
train_net.WeightedSum([B, ONE, gradient_map[B], LR], B)

graph = net_drawer.GetPydotGraph(train_net.Proto().op, "train", rankdir="LR")
display.Image(graph.create_png(), width=800)

再次展示計算圖



既然我們創(chuàng)建了網(wǎng)絡(luò),那么跑起來

workspace.RunNetOnce(init_net)
workspace.CreateNet(train_net)

在我們開始訓(xùn)練之前享幽,先來看看參數(shù):

print("Before training, W is: {}".format(workspace.FetchBlob("W")))
print("Before training, B is: {}".format(workspace.FetchBlob("B")))

參數(shù)初始化如下

Before training, W is: [[-0.77634162 -0.88467366]]
Before training, B is: [ 0.]

訓(xùn)練:

for i in range(100):
    workspace.RunNet(train_net.Proto().name)

迭代100次后铲掐,查看參數(shù):

print("After training, W is: {}".format(workspace.FetchBlob("W")))
print("After training, B is: {}".format(workspace.FetchBlob("B")))

print("Ground truth W is: {}".format(workspace.FetchBlob("W_gt")))
print("Ground truth B is: {}".format(workspace.FetchBlob("B_gt")))

參數(shù)如下:

After training, W is: [[ 1.95769441  1.47348857]]
After training, B is: [ 0.45236012]
Ground truth W is: [[ 2.   1.5]]
Ground truth B is: [ 0.5]

看起來相當(dāng)簡單是不是?讓我們再近距離看看訓(xùn)練過程中參數(shù)的更新過程值桩。為此摆霉,我們重新初始化參數(shù),看看每次迭代參數(shù)的變化奔坟。記住携栋,我們可以在任何時候從workspace中取出我們的blobs。

workspace.RunNetOnce(init_net)
w_history = []
b_history = []
for i in range(50):
    workspace.RunNet(train_net.Proto().name)
    w_history.append(workspace.FetchBlob("W"))
    b_history.append(workspace.FetchBlob("B"))
w_history = np.vstack(w_history)
b_history = np.vstack(b_history)
pyplot.plot(w_history[:, 0], w_history[:, 1], 'r')
pyplot.axis('equal')
pyplot.xlabel('w_0')
pyplot.ylabel('w_1')
pyplot.grid(True)
pyplot.figure()
pyplot.plot(b_history)
pyplot.xlabel('iter')
pyplot.ylabel('b')
pyplot.grid(True)

你可以發(fā)現(xiàn)非常典型的批梯度下降表現(xiàn):由于噪聲的影響咳秉,訓(xùn)練過程中存在波動婉支。在Ipython notebook中跑多幾次這個案例,你將會看到不同的初始化和噪聲的影響滴某。
當(dāng)然磅摹,這只是一個玩玩的例子,在MNIST例程中霎奢,我們將會看到一個更加真實的CNN訓(xùn)練的例子户誓。

譯者注: 轉(zhuǎn)載請注明出處:http://www.reibang.com/c/cf07b31bb5f2

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市幕侠,隨后出現(xiàn)的幾起案子帝美,更是在濱河造成了極大的恐慌,老刑警劉巖晤硕,帶你破解...
    沈念sama閱讀 222,590評論 6 517
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件悼潭,死亡現(xiàn)場離奇詭異,居然都是意外死亡舞箍,警方通過查閱死者的電腦和手機(jī)舰褪,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 95,157評論 3 399
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來疏橄,“玉大人占拍,你說我怎么就攤上這事略就。” “怎么了晃酒?”我有些...
    開封第一講書人閱讀 169,301評論 0 362
  • 文/不壞的土叔 我叫張陵表牢,是天一觀的道長。 經(jīng)常有香客問我贝次,道長崔兴,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 60,078評論 1 300
  • 正文 為了忘掉前任蛔翅,我火速辦了婚禮敲茄,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘搁宾。我一直安慰自己折汞,他們只是感情好,可當(dāng)我...
    茶點故事閱讀 69,082評論 6 398
  • 文/花漫 我一把揭開白布盖腿。 她就那樣靜靜地躺著爽待,像睡著了一般。 火紅的嫁衣襯著肌膚如雪翩腐。 梳的紋絲不亂的頭發(fā)上鸟款,一...
    開封第一講書人閱讀 52,682評論 1 312
  • 那天,我揣著相機(jī)與錄音茂卦,去河邊找鬼何什。 笑死,一個胖子當(dāng)著我的面吹牛等龙,可吹牛的內(nèi)容都是我干的处渣。 我是一名探鬼主播,決...
    沈念sama閱讀 41,155評論 3 422
  • 文/蒼蘭香墨 我猛地睜開眼蛛砰,長吁一口氣:“原來是場噩夢啊……” “哼罐栈!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起泥畅,我...
    開封第一講書人閱讀 40,098評論 0 277
  • 序言:老撾萬榮一對情侶失蹤荠诬,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后位仁,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體柑贞,經(jīng)...
    沈念sama閱讀 46,638評論 1 319
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 38,701評論 3 342
  • 正文 我和宋清朗相戀三年聂抢,在試婚紗的時候發(fā)現(xiàn)自己被綠了钧嘶。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 40,852評論 1 353
  • 序言:一個原本活蹦亂跳的男人離奇死亡琳疏,死狀恐怖康辑,靈堂內(nèi)的尸體忽然破棺而出摄欲,到底是詐尸還是另有隱情轿亮,我是刑警寧澤疮薇,帶...
    沈念sama閱讀 36,520評論 5 351
  • 正文 年R本政府宣布,位于F島的核電站我注,受9級特大地震影響按咒,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜但骨,卻給世界環(huán)境...
    茶點故事閱讀 42,181評論 3 335
  • 文/蒙蒙 一励七、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧奔缠,春花似錦掠抬、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,674評論 0 25
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至闷哆,卻和暖如春腰奋,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背抱怔。 一陣腳步聲響...
    開封第一講書人閱讀 33,788評論 1 274
  • 我被黑心中介騙來泰國打工劣坊, 沒想到剛下飛機(jī)就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人屈留。 一個月前我還...
    沈念sama閱讀 49,279評論 3 379
  • 正文 我出身青樓局冰,卻偏偏與公主長得像,于是被迫代替她去往敵國和親灌危。 傳聞我的和親對象是個殘疾皇子康二,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 45,851評論 2 361

推薦閱讀更多精彩內(nèi)容