Keras為我們提供了很多已經(jīng)定義好的網(wǎng)絡(luò)准脂,比如Embdding層邪媳,LSTM層捐顷,GRU等。但是在有些情況下雨效,這些預(yù)先定義好的網(wǎng)絡(luò)層并不能很好的滿足我們的需求迅涮,這個(gè)時(shí)候我們就需要自定義網(wǎng)絡(luò)層。
當(dāng)然徽龟,通過(guò)線上叮姑,我們可以很方便的查閱到大把資料關(guān)于使用keras定義自己的網(wǎng)絡(luò)層,很多blog都直入主題枕赵,直接告訴我們自定義layer需要涉及到三個(gè)方法---------build()半夷,call()凉唐,以及compute_output_shape(),處理好這幾個(gè)方法朱盐,我們便可以實(shí)現(xiàn)我們所需要的功能。但是知其然也要知其所以然菠隆,因而今天我們通過(guò)閱讀keras/engine/base_layer中的Layer類兵琳,來(lái)更好的理解整個(gè)網(wǎng)絡(luò)層的運(yùn)行過(guò)程。
上面我將Layer類中的一些關(guān)鍵方法貼出來(lái)骇径,為了更加直觀理解躯肌,方法的具體內(nèi)容都刪了,后面具體分析的時(shí)候破衔,在進(jìn)行補(bǔ)充羡榴。
首先我們可以看到,在Layer類中有兩個(gè)特殊的方法运敢,__init__()和__call__()校仑。
__init__()是構(gòu)造方法忠售,當(dāng)我們建立類對(duì)象時(shí),首先調(diào)用該方法初始化類對(duì)象迄沫。
__call__()是可調(diào)用方法稻扬,一旦實(shí)現(xiàn)該方法,我們的類對(duì)象在某些行為上可以表現(xiàn)的和函數(shù)一樣羊瘩√┘眩可以直接通過(guò)類對(duì)象object()進(jìn)行調(diào)用。下面舉個(gè)例子尘吗。
上面我們定義了一個(gè)類逝她,我們可以發(fā)現(xiàn)可以直接通過(guò)類對(duì)象()來(lái)調(diào)用__call__方法〔谴罚【? obj()等價(jià)于obj.__call__()? 】
這里我們也可以明白為什么平時(shí)可以直接使用例如? LSTM(32)(input)這種形式來(lái)添加網(wǎng)絡(luò)層黔宛。其實(shí)這種形式本質(zhì)是
上面的實(shí)例,我們也可以知道擒贸,在layer中__call__方法的參數(shù)是input臀晃,返回值是output。那么__call__方法究竟做了什么介劫?下面貼關(guān)鍵源碼(方便理解整個(gè)流程徽惋,貼完整的不易理解)感興趣的可以對(duì)照源碼理解。
OK!? 觀察上述代碼座韵,我們可以知道在__call__()方法中有幾個(gè)關(guān)鍵操作险绘,調(diào)用build(),調(diào)用call()誉碴,調(diào)用compute_output_shape()隆圆,最后再利用node將該層和上一層鏈接起來(lái)(如何鏈接可以不用關(guān)心)。
emmmmmmm翔烁,到這里渺氧,其實(shí)就是整個(gè)網(wǎng)絡(luò)層的運(yùn)行流程了。大家看懂了就可以撤了蹬屹。
(somebody :"侣背。。慨默。贩耐。。厦取。潮太。what Fuck! 你這講的都是啥,我還云里霧里呢!")
好吧铡买,為了不讓網(wǎng)友罵更鲁,我接著將build()等幾個(gè)方法具體分析。
build():我們知道奇钞,當(dāng)我們定義網(wǎng)絡(luò)層的時(shí)候澡为,需要用到一些張量(tensor)來(lái)對(duì)我們的輸入進(jìn)行操作。比如權(quán)重信息Weights,偏差Biases景埃。其實(shí)一個(gè)網(wǎng)絡(luò)本身就可以理解為這些張量的集合媒至。keras是如何在我們給定input以及output_dim的情況下定義這些張量的呢?這里主要就是build()方法的功勞了谷徙。build函數(shù)就是為該網(wǎng)絡(luò)定義一層相應(yīng)的張量的集合拒啰。在Layer類中有兩個(gè)成員變量,分別是trainable_weights和non_trainable_weights完慧,分別是指可以訓(xùn)練的參數(shù)的集合和不可訓(xùn)練的參數(shù)的集合谋旦。這兩個(gè)參數(shù)都是list。在build中建立的張量通過(guò)add_weight()方法加入到上面兩個(gè)張量集合中骗随,進(jìn)而建立網(wǎng)絡(luò)層。需要注意的是赴叹,一個(gè)網(wǎng)絡(luò)層的參數(shù)是固定的鸿染,我們不能重復(fù)添加,因此乞巧,build()方法最多只能調(diào)用一次涨椒。如何保證每個(gè)layer的build()最多調(diào)用一次?绽媒?蚕冬?這是通過(guò)self._built變量來(lái)控制的。如果built變量為True是辕,那么build()方法將不再會(huì)被調(diào)用囤热,否則build()才能被調(diào)用。在調(diào)用之后built會(huì)被賦值為True获三,防止以后build()被重復(fù)調(diào)用旁蔼。這在__call__()方法中有體現(xiàn)。所以我們?nèi)绾螞]有重新寫__call__()方法疙教,那么我們不用擔(dān)心build()方法會(huì)被多次調(diào)用棺聊。但是如果重新寫了__call__()方法,一定要注意在build()調(diào)用之后贞谓,將built置為True限佩。【TIP:build只接受input一個(gè)參數(shù),所以如果需要用到output_shape祟同,可以在__init__()中將output_shape賦值給一個(gè)成員變量作喘,這樣就可以在build中直接使用output_shape的值了】。舉個(gè)簡(jiǎn)單例子耐亏,以output=tanh(X*W+B),我們首先定義build()函數(shù)徊都。這里用到的參數(shù)分別是W和B。假設(shè)輸出的大小為output_dim广辰,且已經(jīng)在__init__()中已經(jīng)初始化了暇矫。
call(): 該方法是整個(gè)網(wǎng)絡(luò)層的邏輯輸出。通過(guò)build()择吊,我們已經(jīng)有了網(wǎng)絡(luò)層的權(quán)重等信息李根,接下來(lái)便是通過(guò)input以及這些權(quán)重張量(W,B)等來(lái)獲得輸出了。如何得到output就要根據(jù)大家需要的功能來(lái)說(shuō)了几睛。//////該函數(shù)返回值是output房轿,__call__()方法也是通過(guò)調(diào)用call()來(lái)獲得輸出output。
comput_output_shape():返回輸出的形狀所森,便于keras搭建下一層網(wǎng)絡(luò)時(shí)囱持,可以自行推導(dǎo)出輸入的形狀
好了,以上便大功告成了【第一次寫blog焕济,希望能幫助大家更好理解keras網(wǎng)絡(luò)層的整個(gè)控制流程】