PyTorch中的backward

注意:簡書數(shù)學公式支持不好蜓堕,建議移步我的博客獲得更佳的閱讀體驗。

接觸了PyTorch這么長的時間年枕,也玩了很多PyTorch的騷操作标捺,都特別簡單直觀地實現(xiàn)了懊纳,但是有一個網(wǎng)絡(luò)訓練過程中的操作之前一直沒有仔細去考慮過,那就是loss.backward()亡容,看到這個大家一定都很熟悉嗤疯,loss是網(wǎng)絡(luò)的損失函數(shù),是一個標量闺兢,你可能會說這不就是反向傳播嗎茂缚,有什么好講的。

但是不知道大家思考過沒有,如果loss不是一個標量阱佛,而是一個向量帖汞,那么loss.backward()是什么結(jié)果呢?

大家可以去試試凑术,寫一個簡單的小程序

import torch as t
from torch.autograd import Variable as v
x = v(t.ones(2, 2), requires_grad=True)
y = x + 1
y.backward()

運行一下程序,恭喜你報錯了所意,錯誤顯示如下

backwarderror.png

我們來讀一讀這個錯誤是什么意思淮逊。backward只能被應(yīng)用在一個標量上,也就是一個一維tensor扶踊,或者傳入跟變量相關(guān)的梯度泄鹏。

嗯,前面一句話很簡單秧耗,backward應(yīng)用在一個標量备籽,平時我們也是這么使用的,但是后面一句話分井,with gradient w.r.t variable是什么鬼车猬,傳入一個變量相關(guān)的梯度。不理解啊不理解尺锚,看不懂沒關(guān)系我們還可以做實驗來解決這個問題珠闰,俗話說自己動手豐衣足食(我也想做個伸手黨去看看別人寫的,然后不幸地是并沒有什么人寫過這方面的東西)瘫辩。

首先我們開始做一個簡單的實驗伏嗜,就是復(fù)習一下標量的形式

# simple gradient
a = v(t.FloatTensor([2, 3]), requires_grad=True)
b = a + 3
c = b * b * 3
out = c.mean()
out.backward()
print('*'*10)
print('=====simple gradient======')
print('input')
print(a.data)
print('compute result is')
print(out.data[0])
print('input gradients are')
print(a.grad.data)

很簡單,我們把數(shù)學表達式寫出來伐厌,傳入的參數(shù)$x_1 = 2, x_2 = 3$承绸,特別注意Variable里面默認的參數(shù)requires_grad=False,所以這里我們要重新傳入requires_grad=True讓它成為一個葉子節(jié)點挣轨。
$$
a = (x_1, x_2) \quad b = (x_1 + 3, x_2 + 3) \quad c = (3 * (x_1+3)^2, 3(x_2 + 3)^2) \quad out=\frac{3((x_1+3)^2 + (x_2 + 3)^2)}{2}
$$
那么我們對其求偏導(dǎo)也很簡單
$$
\frac{\partial out}{\partial x_1} = 3(x_1 + 3)|{x_1=2}=15 \quad \frac{\partial out}{\partial x_2} = 3(x_2 + 3)|{x_2=3} = 18
$$
這樣依靠簡單的微積分知識我們就能夠算出他們的結(jié)果军熏,運行一下程序,確保結(jié)果一致刃唐,ok羞迷。

Paste_Image.png

下面我們研究一下如何能夠?qū)Ψ菢肆康那闆r下使用backward,下面開始做實驗(瞎試)画饥。

m = v(t.FloatTensor([[2, 3]]), requires_grad=True)
n = v(t.zeros(1, 2))
n[0, 0] = m[0, 0] ** 2
n[0, 1] = m[0, 1] ** 3

首先我們定義好輸入$m = (x_1, x_2) = (2, 3)$衔瓮,然后我們做的操作就是$n = (x_1^2, x_2^3)$,這樣我們就定義好了一個向量輸出抖甘,結(jié)果第一項只和$x_1$有關(guān)热鞍,結(jié)果第二項只和$x_2$有關(guān),那么求解這個梯度,我們知道$\frac{\partial n_1}{\partial x_1} = 2 x_1 = 4, \frac{\partial n_2}{\partial x_2} = 3 x_2^2 = 27$ 薇宠,下面我們開始探究如何能夠讓他調(diào)用backward偷办。

第一想法就是里面這個參數(shù)是要求梯度的對象,我們這樣調(diào)用n.backward(m.data)澄港,有有報錯誒椒涯,是不是成功了,我真的是個天才回梧,這么難的東西都能想到废岂,等等,我好想看到了一個很神奇的結(jié)果狱意。

Paste_Image.png

這是什么鬼湖苞,這跟說好的結(jié)果不一樣啊,我們想要的結(jié)果是4和27,現(xiàn)在給我們的結(jié)果是8和81,為什么會出現(xiàn)這樣神奇的結(jié)果呢详囤,想不通啊财骨。我們看看我們傳入的參數(shù)是m.data,這是一個(2, 3)的向量藏姐,我們希望得到的梯度是(4, 27)隆箩,好像($42=8, 273=81$),我的內(nèi)心毫無波動包各,甚至有點想笑摘仅,似乎backward將我傳入的參數(shù)m.data乘上了得到的梯度,既然要乘上我傳入的參數(shù)问畅,那么我就給你傳入1,這樣總能得到我想要的結(jié)果了吧娃属,n.backward(t.FloatTensor([[1, 1]])),看看結(jié)果呢

backwardresult2.png

哇护姆,跟我們想要的結(jié)果一樣誒矾端,撒花,我們解決了一個大問題卵皂,就是這么簡單秩铆,扔進去一個1就可以了,這個問題也沒有那么難嘛灯变,哈哈哈殴玛。

似乎又有一點不對,如果這么簡單那么寫PyTorch的人為什么不把這一步直接集成進去添祸,那我們不就不會遇到這個問題了嘛滚粟。

Paste_Image.png

我們來試試另外一種情況

m = v(t.FloatTensor([[2, 3]]), requires_grad=True)
j = t.zeros(2 ,2)
k = v(t.zeros(1, 2))
m.grad.data.zero_()
k[0, 0] = m[0, 0] ** 2 + 3 * m[0 ,1]
k[0, 1] = m[0, 1] ** 2 + 2 * m[0, 0]

上面的代碼寫成數(shù)學表達式就是$m = (x_1=2, x_2=3), k = (x_1^2 + 3x_2, x_2^2+2x_1)$,么我們直接對k反向傳播k.backward(t.FloatTensor([[1, 1]])刃泌,結(jié)果是什么呢凡壤?

首先我們手動算一算結(jié)果是什么署尤。$\frac{\partial (x_1^2 + 3x_2)}{\partial x_1 } = 2x_1=4,\ \frac{\partial (x_1^2 + 3x_2)}{\partial x_2 } = 3,\ \frac{\partial (x_2^2 + 2x_1)}{\partial x_1} = 2,\ \frac{\partial (x_2^2 + 2x_1)}{\partial x_2} = 2x_2 = 6$,我們是希望能夠得到上面四個結(jié)果亚侠,這個時候你可能已經(jīng)開始懷疑了曹体,能夠得到這4個結(jié)果嗎?我們可以輸出結(jié)果來看看

backwardresult3.png

非常遺憾硝烂,我們只得到了兩個結(jié)果箕别,并且數(shù)值并不對,這個時候你就會疑惑了滞谢,到底是哪里出了問題呢究孕,為什么會得到這樣的結(jié)果呢?

經(jīng)過不斷地嘗試爹凹,我終于發(fā)現(xiàn)了其中的奧秘,k.backward(parameters)接受的參數(shù)parameters必須要和k的大小一模一樣镶殷,然后作為k的系數(shù)傳回去禾酱,什么意思呢,我們通過上面的例子來解釋這個問題你就知道了绘趋。

我們已經(jīng)知道我們得到的$k = (k_1, k_2)$颤陶,以及傳入的參數(shù)是1和1,那么是如何得到這6和9這兩個結(jié)果的呢陷遮?

其實第一個結(jié)果是通過$1 * \frac{d k_1}{d x_1} + 1 * \frac{d k_2}{d x_1} = 2 x_1 + 2 = 6$這樣得到的滓走,是不是有點理解這個操作是怎么完成的了,我們再來看看第二個結(jié)果帽馋,$ 1 * \frac{d k_1}{d x_2} + 1 * \frac{d k_2}{d x_2} = 3+2 x_2 = 9$搅方,這樣我們就得到了這兩個結(jié)果,原來我們傳入的參數(shù)是每次求導(dǎo)的一個系數(shù)绽族。

我們知道了這個操作具體是怎么完成的姨涡,我們就可以求求我們需要的這個jacobian矩陣了,非常簡單吧慢。

# jacobian
j = t.zeros(2 ,2)
k = v(t.zeros(1, 2))
m.grad.data.zero_()
k[0, 0] = m[0, 0] ** 2 + 3 * m[0 ,1]
k[0, 1] = m[0, 1] ** 2 + 2 * m[0, 0]
k.backward(t.FloatTensor([[1, 0]]), retain_variables=True)
j[:, 0] = m.grad.data
m.grad.data.zero_()
k.backward(t.FloatTensor([[0, 1]]))
j[:, 1] = m.grad.data
print('jacobian matrix is')
print(j)

我們可以得到如下結(jié)果

Paste_Image.png

這里我們要注意backward()里面另外的一個參數(shù)retain_variables=True涛漂,這個參數(shù)默認是False,也就是反向傳播之后這個計算圖的內(nèi)存會被釋放检诗,這樣就沒辦法進行第二次反向傳播了匈仗,所以我們需要設(shè)置為True,因為這里我們需要進行兩次反向傳播求得jacobian矩陣逢慌。

最后我們再舉一個矩陣乘法的例子試驗一下我們的結(jié)果

x = t.FloatTensor([2, 1]).view(1, 2)
x = v(x, requires_grad=True)
y = v(t.FloatTensor([[1, 2], [3, 4]]))

z = t.mm(x, y)
jacobian = t.zeros((2, 2))
z.backward(t.FloatTensor([[1, 0]]), retain_variables=True)  # dz1/dx1, dz2/dx1
jacobian[:, 0] = x.grad.data
x.grad.data.zero_()
z.backward(t.FloatTensor([[0, 1]]))  # dz1/dx2, dz2/dx2
jacobian[:, 1] = x.grad.data
print('=========jacobian========')
print('x')
print(x.data)
print('y')
print(y.data)
print('compute result')
print(z.data)
print('jacobian matrix is')
print(jacobian)

上面是代碼悠轩,仔細閱讀,作為一個小練習回顧一下本篇文章講的內(nèi)容涕癣,媽媽再也不用擔心我不會用backward了哗蜈。


本文代碼已經(jīng)上傳到了github

歡迎查看我的知乎專欄前标,深度煉丹

歡迎訪問我的博客

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市距潘,隨后出現(xiàn)的幾起案子炼列,更是在濱河造成了極大的恐慌,老刑警劉巖音比,帶你破解...
    沈念sama閱讀 219,366評論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件俭尖,死亡現(xiàn)場離奇詭異,居然都是意外死亡洞翩,警方通過查閱死者的電腦和手機稽犁,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,521評論 3 395
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來骚亿,“玉大人已亥,你說我怎么就攤上這事±赐溃” “怎么了虑椎?”我有些...
    開封第一講書人閱讀 165,689評論 0 356
  • 文/不壞的土叔 我叫張陵,是天一觀的道長俱笛。 經(jīng)常有香客問我捆姜,道長,這世上最難降的妖魔是什么迎膜? 我笑而不...
    開封第一講書人閱讀 58,925評論 1 295
  • 正文 為了忘掉前任泥技,我火速辦了婚禮,結(jié)果婚禮上磕仅,老公的妹妹穿的比我還像新娘珊豹。我一直安慰自己,他們只是感情好宽涌,可當我...
    茶點故事閱讀 67,942評論 6 392
  • 文/花漫 我一把揭開白布平夜。 她就那樣靜靜地躺著,像睡著了一般卸亮。 火紅的嫁衣襯著肌膚如雪忽妒。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,727評論 1 305
  • 那天兼贸,我揣著相機與錄音段直,去河邊找鬼。 笑死溶诞,一個胖子當著我的面吹牛鸯檬,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播螺垢,決...
    沈念sama閱讀 40,447評論 3 420
  • 文/蒼蘭香墨 我猛地睜開眼喧务,長吁一口氣:“原來是場噩夢啊……” “哼赖歌!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起功茴,我...
    開封第一講書人閱讀 39,349評論 0 276
  • 序言:老撾萬榮一對情侶失蹤庐冯,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后坎穿,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體展父,經(jīng)...
    沈念sama閱讀 45,820評論 1 317
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,990評論 3 337
  • 正文 我和宋清朗相戀三年玲昧,在試婚紗的時候發(fā)現(xiàn)自己被綠了栖茉。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 40,127評論 1 351
  • 序言:一個原本活蹦亂跳的男人離奇死亡孵延,死狀恐怖吕漂,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情尘应,我是刑警寧澤痰娱,帶...
    沈念sama閱讀 35,812評論 5 346
  • 正文 年R本政府宣布,位于F島的核電站菩收,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏鲸睛。R本人自食惡果不足惜娜饵,卻給世界環(huán)境...
    茶點故事閱讀 41,471評論 3 331
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望官辈。 院中可真熱鬧箱舞,春花似錦、人聲如沸拳亿。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,017評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽肺魁。三九已至电湘,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間鹅经,已是汗流浹背寂呛。 一陣腳步聲響...
    開封第一講書人閱讀 33,142評論 1 272
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留瘾晃,地道東北人贷痪。 一個月前我還...
    沈念sama閱讀 48,388評論 3 373
  • 正文 我出身青樓,卻偏偏與公主長得像蹦误,于是被迫代替她去往敵國和親劫拢。 傳聞我的和親對象是個殘疾皇子肉津,可洞房花燭夜當晚...
    茶點故事閱讀 45,066評論 2 355

推薦閱讀更多精彩內(nèi)容