上一篇說(shuō)要嘗試一下用 se_ResNeXt 來(lái)給 WS-DAN 網(wǎng)絡(luò)提取特征,在此之前需要先搞懂 ResNeXt 的原理戈锻,而 ResNeXt 則是在 ResNet 基礎(chǔ)上的改進(jìn),所以繞了一大圈,還得從 ResNet 開(kāi)始。說(shuō)來(lái)慚愧佃迄,之前只是用過(guò) ResNet 來(lái)做分類(lèi)任務(wù),論文還真沒(méi)有仔細(xì)讀過(guò)竿音,正好趁這個(gè)機(jī)會(huì)讀一讀這篇“神作”和屎。
論文地址: https://arxiv.org/pdf/1512.03385.pdf
論文閱讀
其實(shí)論文的思想在今天看來(lái)是不難的拴驮,不過(guò)在當(dāng)時(shí) ResNet 提出的時(shí)候可是橫掃了各大分類(lèi)任務(wù)春瞬,這個(gè)網(wǎng)絡(luò)解決了隨著網(wǎng)絡(luò)的加深,分類(lèi)的準(zhǔn)確率不升反降的問(wèn)題套啤。通過(guò)一個(gè)名叫“殘差”的網(wǎng)絡(luò)結(jié)構(gòu)(如下圖所示)宽气,使作者可以只通過(guò)簡(jiǎn)單的網(wǎng)絡(luò)深度堆疊便可達(dá)到提升準(zhǔn)確率的目的。
殘差結(jié)構(gòu)的處理過(guò)程分成兩個(gè)部分潜沦,左邊的 與右邊的 萄涯,最后結(jié)果為兩者相加。其中右邊那根線不會(huì)對(duì) 做任何處理唆鸡,所以沒(méi)有可學(xué)習(xí)的參數(shù)涝影; 為網(wǎng)絡(luò)中負(fù)責(zé)學(xué)習(xí)特征的部分,把整個(gè)殘差結(jié)構(gòu)看做是一個(gè) 函數(shù)的話争占,則負(fù)責(zé)學(xué)習(xí)的部分可以表示為 燃逻,這個(gè)結(jié)構(gòu)學(xué)習(xí)的其實(shí)是輸出結(jié)果與輸入的差值序目,這也是殘差名字的由來(lái)。完整的 ResNet 網(wǎng)絡(luò)由多個(gè)上圖中所示的殘差結(jié)構(gòu)組成伯襟,每個(gè)結(jié)構(gòu)學(xué)習(xí)的都是輸出與輸入之間的差值猿涨,通過(guò)步步逼近,達(dá)到了比直接學(xué)習(xí)輸入好得多的效果姆怪。
文中殘差結(jié)構(gòu)的具體實(shí)現(xiàn)分為兩種叛赚,首先介紹 ResNet-18 與 ResNet-34 使用的殘差結(jié)構(gòu)稱(chēng)為 Basic Block,如下圖所示稽揭,圖中的結(jié)構(gòu)包含了兩個(gè)卷積操作用于提取特征俺附。
對(duì)應(yīng)到代碼中,這是 Pytorch 自帶的 ResNet 實(shí)現(xiàn)中的一部分溪掀,跟上圖對(duì)應(yīng)起來(lái)看更加好理解昙读,我個(gè)人比較喜歡論文與代碼結(jié)合起來(lái)看,因?yàn)槲页诵枰涝碇馀蚯牛惨廊绾稳ナ褂寐耄a更給我一種一目了然的感覺(jué):
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
另一種殘差結(jié)構(gòu)稱(chēng)為 Bottleneck,就是瓶頸的意思:
作者起名字真的很形象只嚣,網(wǎng)絡(luò)結(jié)構(gòu)也正如這瓶頸一樣沮稚,首先做一個(gè)降維,然后做卷積册舞,然后升維蕴掏,這樣做的好處是可以大大減少計(jì)算量,專(zhuān)門(mén)用于網(wǎng)絡(luò)層數(shù)較深的的網(wǎng)絡(luò)调鲸,ResNet-50 以上的網(wǎng)絡(luò)都有這種基礎(chǔ)結(jié)構(gòu)構(gòu)成(不同層級(jí)的輸入輸出維度可能會(huì)不一樣盛杰,但結(jié)構(gòu)類(lèi)似):
Pytorch 中的代碼,注意到上圖中為了減少計(jì)算量藐石,作者將 256 維的輸入縮小了 4 倍變?yōu)?64 進(jìn)入卷積即供,在升維時(shí)需要升到 256 維,對(duì)應(yīng)代碼中的 expansion 參數(shù):
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
由上面介紹的基本結(jié)構(gòu)再加上池化以及全連接層于微,就構(gòu)成了各種完整的網(wǎng)絡(luò):
圖中的網(wǎng)絡(luò)在 Pytorch 中都已經(jīng)集成進(jìn)去了逗嫡,而且都是預(yù)訓(xùn)練好的,我們可以在預(yù)訓(xùn)練好的模型上面訓(xùn)練自己的分類(lèi)器株依,大大減少我們的訓(xùn)練時(shí)間驱证。下面簡(jiǎn)單介紹一下如何使用 ResNet。
在 Pytorch 中使用 ResNet
Pytorch 是一個(gè)對(duì)初學(xué)者很友好的深度學(xué)習(xí)框架恋腕,入門(mén)的話非常推薦抹锄,官方提供了一小時(shí)入門(mén)教程:https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html
在 Pytorch 中使用 ResNet 只需要 4 行代碼:
from torch import nn
# torchvision 專(zhuān)用于視覺(jué)方面
import torchvision
# pretrained :使用在 ImageNet 數(shù)據(jù)集上預(yù)訓(xùn)練的模型
model = torchvision.models.resnet18(pretrained=True)
# 修改模型的全連接層使其輸出為你需要類(lèi)型數(shù),這里是10
# 由于使用了預(yù)訓(xùn)練的模型 而預(yù)訓(xùn)練的模型輸出為1000類(lèi),所以要修改全連接層
# 若不使用預(yù)訓(xùn)練的模型可以直接在創(chuàng)建模型時(shí)添加參數(shù) num_classes=10 而不需要修改全連接層
model.fc = nn.Linear(model.fc.in_features, 10)
下面你就可以使用這個(gè)模型來(lái)做分類(lèi)了伙单,當(dāng)然到這里還沒(méi)在自己的數(shù)據(jù)集上進(jìn)行訓(xùn)練呆万,關(guān)于如何訓(xùn)練可以參考官方教程:https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
如果對(duì)代碼以及源碼有疑問(wèn)的話可以在下面留言我們一起討論。
最后车份,求贊求關(guān)注谋减,歡迎關(guān)注我的微信公眾號(hào)[MachineLearning學(xué)習(xí)之路] ,深度學(xué)習(xí) & CV 方向的童鞋不要錯(cuò)過(guò)Iㄕ印出爹!