之前對(duì)于這一方面了解較少淆九,搭網(wǎng)絡(luò)就是直接一層一層堆砌,簡單粗暴。當(dāng)然這樣做是不對(duì)的炭庙,寫出來的代碼不僅難看饲窿,集成度不高,遷移困難焕蹄,而且還容易出錯(cuò)逾雄,并且對(duì)于一些情況,簡單一層一層堆網(wǎng)絡(luò)是解決不了的擦盾。因此嘲驾,了解一下pytorch container的相關(guān)內(nèi)容還是有必要的淌哟。
1. nn.Module
這個(gè)是最常用的container迹卢,所有其他網(wǎng)絡(luò)都是這個(gè)類的繼承。我們?cè)谧约憾x一個(gè)網(wǎng)絡(luò)或者層時(shí)徒仓,就需要繼承這個(gè)類腐碱。module允許以樹結(jié)構(gòu)進(jìn)行嵌入,一個(gè)module可以包含其他module掉弛,這個(gè)module就是原有module的submodule症见。
class MyModule(nn.Module):
def __init__(self):
self.conv1 = nn.Conv2d(16, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 32, 3, 1)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
上面的例子中,conv1就是MyModule的submodule殃饿。當(dāng)對(duì)MyModule的實(shí)例進(jìn)行cuda轉(zhuǎn)換時(shí)谋作,conv1作為其submodule,也會(huì)被轉(zhuǎn)換為cuda數(shù)據(jù)格式乎芳,并且遵蚜,這一過程是遞歸進(jìn)行的。這很重要奈惑,需要認(rèn)真理解吭净。如果不是nn.Module的實(shí)例,就不會(huì)被加入到計(jì)算圖中肴甸,也不會(huì)被轉(zhuǎn)換為cuda格式寂殉。今天我就趟了一個(gè)這樣的坑,模型怎么都訓(xùn)不出來原在,最后發(fā)現(xiàn)計(jì)算圖里根本沒有層友扰。
-
add_module
除了上面的做法,也可以用add_module添加一個(gè)層到網(wǎng)絡(luò)里庶柿,這樣做的好處是可以給層命名村怪,這樣就可以直接通過層名來找到一個(gè)層了。
class MyModule(nn.Module):
def __init__(self):
self.add_module('conv1', nn.Conv2d(16, 32, 3, 1))
self.add_module('conv2', nn.Conv2d(32, 32, 3, 1))
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
m = MyModule()
m1 = m.conv1
2. nn.Sequential
這個(gè)在網(wǎng)絡(luò)中的出現(xiàn)次數(shù)也比較頻繁澳泵,通過這個(gè)模塊实愚,可以將代碼寫得更密集一點(diǎn),可讀性也更強(qiáng)。sequential是一個(gè)時(shí)序模型腊敲,根據(jù)每個(gè)submodule傳入的順序?qū)懙接?jì)算圖里击喂,在forward的時(shí)候也會(huì)順序執(zhí)行。
conv1 = nn.Sequential(
nn.Conv2d(32, 64, 3, 1),
nn.BatchNorm2d(64),
nn.ReLU()
)
當(dāng)然碰辅,也可以傳一個(gè)OrderedDict來構(gòu)造網(wǎng)絡(luò)懂昂。
3. nn.ModuleList
如果不想讓module按照傳入順序執(zhí)行,就可以將它們寫成一個(gè)list没宾,用下標(biāo)來進(jìn)行索引凌彬。但是哪怕是直接在init函數(shù)里定義為成員變量,最后也不會(huì)被加到計(jì)算圖里循衰,真的是很心煩铲敛。我今天就跳了一個(gè)這樣的坑,正打算自己親自寫一個(gè)類時(shí)会钝,發(fā)現(xiàn)pytorch已經(jīng)幫我們實(shí)現(xiàn)好了伐蒋,就是ModuleList。官方大法好~~
ModuleList也是繼承了Module的一個(gè)子類迁酸,可以像python list一樣用下標(biāo)索引先鱼,可以使用append和extend方法,最重要的是奸鬓,也會(huì)被加到計(jì)算圖里焙畔,總之能用就是了。
MyModule = nn.ModuleList()
MyModule.append(nn.Conv2d(32, 64, 3, 1))
MyModule.append(nn.Conv2d(32, 64, 1, 1))
m = MyModule()
x1 = m[0](x)
x2 = m[1](x)