參考:https://github.com/datawhalechina/thorough-pytorch
1.Pytorch模型定義的方式
模型是深度學習中重要的組成部分挑格,是解決問題的關鍵所在倘核。
Module 類是 torch.nn 模塊里提供的一個模型構造類 (nn.Module)框冀,是所有神經(jīng)?網(wǎng)絡模塊的基類,我們可以繼承它來定義我們想要的模型杯瞻;
PyTorch模型定義應包括兩個主要部分:各個部分的初始化(_init_)沛慢;數(shù)據(jù)流向定義(forward)
基于nn.Module肾砂,我們可以通過Sequential,ModuleList和ModuleDict三種方式定義PyTorch模型扣溺。
1.1 Sequential
對應模塊為nn.Sequential()骇窍。
當模型的前向計算為簡單串聯(lián)各個層的計算時, Sequential 類可以通過更加簡單的方式定義模型锥余。它可以接收一個子模塊的有序字典(OrderedDict) 或者一系列子模塊作為參數(shù)來逐一添加 Module 的實例腹纳,?模型的前向計算就是將這些實例按添加的順序逐?計算。
sequential定義模型需要將模型的層按序排列起來驱犹,根據(jù)層名不同嘲恍,排列有兩種方式:直接排列和使用OrderedDict。
使用Sequential定義模型的好處在于簡單雄驹、易讀佃牛,同時使用Sequential定義的模型不需要再寫forward,因為順序已經(jīng)定義好了医舆。但使用Sequential也會使得模型定義喪失靈活性俘侠,比如需要在模型中間加入一個外部輸入時就不適合用Sequential的方式實現(xiàn)。使用時需根據(jù)實際需求加以選擇蔬将。
1.2 ModuleList
對應模塊為nn.ModuleList()爷速。
ModuleList 接收一個子模塊(或層,需屬于nn.Module類)的列表作為輸入霞怀,然后也可以類似List那樣進行append和extend操作惫东。同時,子模塊或層的權重也會自動添加到網(wǎng)絡中來毙石。
nn.ModuleList 并沒有定義一個網(wǎng)絡凿蒜,它只是將不同的模塊儲存在一起禁谦。ModuleList中元素的先后順序并不代表其在網(wǎng)絡中的真實位置順序胁黑,需要經(jīng)過forward函數(shù)指定各個層的先后順序后才算完成了模型的定義废封。具體實現(xiàn)時用for循環(huán)即可完成。
1.3 ModuleDict
對應模塊為nn.ModuleDict()丧蘸。
ModuleDict和ModuleList的作用類似漂洋,只是ModuleDict能夠更方便地為神經(jīng)網(wǎng)絡的層添加名稱。
1.4 總結
Sequential適用于快速驗證結果力喷,因為已經(jīng)明確了要用哪些層刽漂,直接寫一下就好了,不需要同時寫__init__和forward弟孟;
ModuleList和ModuleDict在某個完全相同的層需要重復出現(xiàn)多次時贝咙,非常方便實現(xiàn),可以”一行頂多行“拂募;
當我們需要之前層的信息的時候庭猩,比如 ResNets 中的殘差計算,當前層的結果需要和之前層中的結果進行融合陈症,一般使用 ModuleList/ModuleDict 比較方便蔼水。
2.利用模型塊快速搭建復雜網(wǎng)絡
用torch.nn中的層來定義Pytorch。這種定義方式易于理解录肯,在實際場景下不一定利于使用趴腋。當模型的深度非常大時候,使用Sequential定義模型結構需要向其中添加幾百行代碼论咏,使用起來不甚方便优炬。
對于大部分模型結構(比如ResNet、DenseNet等)厅贪,我們仔細觀察就會發(fā)現(xiàn)蠢护,雖然模型有很多層, 但是其中有很多重復出現(xiàn)的結構卦溢『啵考慮到每一層有其輸入和輸出,若干層串聯(lián)成的”模塊“也有其輸入和輸出单寂,如果我們能將這些重復出現(xiàn)的層定義為一個”模塊“贬芥,每次只需要向網(wǎng)絡中添加對應的模塊來構建模型,這樣將會極大便利模型構建的過程宣决。
以U-Net為例蘸劈,介紹如何構建模型塊,以及如何利用模型塊快速搭建復雜模型尊沸。
2.1?U-Net模型塊分析
模型從上到下分為若干層威沫,每層由左側和右側兩個模型塊組成贤惯,每側的模型塊與其上下模型塊之間有連接;同時位于同一層左右兩側的模型塊之間也有連接棒掠,稱為“Skip-connection”孵构。此外還有輸入和輸出處理等其他組成部分。由于模型的形狀非常像英文字母的“U”烟很,因此被命名為“U-Net”颈墅。
組成U-Net的模型塊主要有如下幾個部分:
1)每個子塊內部的兩次卷積(Double Convolution)
2)左側模型塊之間的下采樣連接,即最大池化(Max pooling)
3)右側模型塊之間的上采樣連接(Up sampling)
4)輸出層的處理
除模型塊外雾袱,還有模型塊之間的橫向連接恤筛,輸入和U-Net底部的連接等計算,這些單獨的操作可以通過forward函數(shù)來實現(xiàn)芹橡。
下面我們用PyTorch先實現(xiàn)上述的模型塊毒坛,然后再利用定義好的模型塊構建U-Net模型。
2.2?U-Net模型塊實現(xiàn)
在使用PyTorch實現(xiàn)U-Net模型時林说,先定義好模型塊煎殷,再定義模型塊之間的連接順序和計算方式。這里的基礎部件對應上一節(jié)分析的四個模型塊述么,根據(jù)功能我們將其命名為:DoubleConv, Down, Up, OutConv蝌数。