ResNet34
网络中有很多结构相似的单元，它们的共同点是都有一个跨层直连的 shortcut。将这样一个带跨层直连的单元称为 Residual block（残差块）；shortcut 的输出要与主分支的输出相加，因此二者的通道数（以及特征图尺寸）必须一致，不一致时需要在 shortcut 上用 1x1 卷积做投影。
将由多个 Residual block 单元组成的结构称为 layer。
Residual block —— 用子 Module（nn.Module 子类）实现
layer —— 用函数（_make_layer）构建
借助 shortcut 这种跨层直连结构，反向传播时梯度信号可以几乎无衰减地传递到浅层，从而缓解了因网络加深、梯度逐层变小而导致的梯度消失问题。
导入模块
from torch import nn
import torch as t
from torch.nn import functional as F
定义子 Module：ResidualBlock
class ResidualBlock(nn.Module):
    """Residual block: two 3x3 conv + BN layers plus a shortcut branch.

    The output is ``relu(left(x) + shortcut(x))``, where ``left`` is
    conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN. When ``shortcut`` is
    None, the input is added unchanged (identity). In that case ``x`` and
    ``left(x)`` must have the same shape.

    Args:
        inchannel: number of input channels.
        outchannel: number of output channels.
        stride: stride of the first 3x3 convolution.
        shortcut: optional module mapping ``x`` to the shape of ``left(x)``;
            None means an identity shortcut.
    """

    def __init__(self, inchannel, outchannel, stride=1, shortcut=None):
        super(ResidualBlock, self).__init__()
        # Main ("left") branch; padding=1 keeps spatial size except for stride.
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 3, stride, 1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, 3, 1, 1, bias=False),
            nn.BatchNorm2d(outchannel),
        )
        # Shortcut ("right") branch; None -> identity in forward().
        self.right = shortcut

    def forward(self, x):
        """Return ``relu(left(x) + shortcut(x))``."""
        out = self.left(x)
        residual = x if self.right is None else self.right(x)
        out += residual
        return F.relu(out)
定义 ResNet，创建模型实例并做一次前向计算（注：下面的代码并没有把模型保存到磁盘）
class ResNet(nn.Module):
    """ResNet34-style image classifier assembled from ``ResidualBlock`` units.

    The network has four parts:

    - a stem: 7x7 conv stride 2 -> BN -> ReLU -> 3x3 max-pool stride 2;
    - four stages containing 3, 4, 6 and 3 residual blocks;
    - global average pooling;
    - a fully connected classifier.

    NOTE(review): the stage widths here are 128/256/512/512, whereas the
    ResNet34 of He et al. uses 64/128/256/512. They are left unchanged so
    that parameter shapes (and any saved ``state_dict``) stay compatible.

    Args:
        num_classes: number of output logits produced by ``fc``.
    """

    def __init__(self, num_classes=1000):
        # Must be ``__init__``: nn.Module's constructor has to run before any
        # submodule is assigned as an attribute.
        super(ResNet, self).__init__()
        self.pre = nn.Sequential(
            nn.Conv2d(3, 64, 7, 2, 3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1),
        )
        # Single underscore on purpose: ``self.__make_layer`` would be
        # name-mangled to ``self._ResNet__make_layer``, which does not exist.
        self.layer1 = self._make_layer(64, 128, 3)
        self.layer2 = self._make_layer(128, 256, 4, stride=2)
        self.layer3 = self._make_layer(256, 512, 6, stride=2)
        self.layer4 = self._make_layer(512, 512, 3, stride=2)
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, inchannel, outchannel, block_num, stride=1):
        """Build one stage of ``block_num`` residual blocks.

        The first block maps ``inchannel`` -> ``outchannel`` with the given
        stride. It uses a 1x1 conv + BN projection shortcut so both branches
        of the residual sum have matching shapes. The remaining blocks keep
        ``outchannel`` channels, use stride 1, and have identity shortcuts.

        Returns:
            nn.Sequential holding the blocks in order.
        """
        shortcut = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 1, stride, bias=False),
            nn.BatchNorm2d(outchannel),
        )
        layers = [ResidualBlock(inchannel, outchannel, stride, shortcut)]
        layers.extend(
            ResidualBlock(outchannel, outchannel) for _ in range(1, block_num)
        )
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape ``(N, num_classes)``.

        ``x`` is expected as ``(N, 3, H, W)``, since the stem conv takes 3
        input channels. H and W must be large enough that every stage keeps
        a spatial size of at least 1x1.

        Adaptive pooling to 1x1 makes the classifier independent of H and W.
        For 224x224 inputs, the final feature map is 7x7, so this equals the
        previous fixed 7x7 average pooling.
        """
        x = self.pre(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        return self.fc(x)
# Smoke test: build the model and run one forward pass on a random batch of
# one 3x224x224 image; ``o`` holds the (1, 1000) logits.
model = ResNet()
# Since PyTorch 0.4, ``Variable`` has been merged into ``Tensor``.
# ``Variable(tensor)`` simply returns the tensor, so a plain tensor is
# equivalent here.
input = t.randn(1, 3, 224, 224)
o = model(input)
深度学习算法的本质是通过反向传播求导数，PyTorch 的 autograd 模块可以自动完成微分。
Variable 曾是 autograd 的核心数据结构：它封装了 tensor，并记录对其执行的操作，以此构建计算图。（PyTorch 0.4 起 Variable 已并入 Tensor，直接使用设置了 requires_grad=True 的 tensor 即可。）