獻(xiàn)給瑩瑩
1. VGG Net網(wǎng)絡(luò)結(jié)構(gòu)
VGG是十分經(jīng)典的網(wǎng)絡(luò)了滚停,沒什么好說的蓄拣。網(wǎng)絡(luò)結(jié)構(gòu)如下
注解:
- LRN層
https://blog.csdn.net/hduxiejun/article/details/70570086 - VGG16
是有1*1的卷積核的
2.搭建過程
- 1.加載必要及準(zhǔn)備工作
import torch
import torch.nn as nn
cfg = {
'VGG11': [64, 'M', 128, 'M', 256,'M', 512, 'M', 512,'M'],
'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
# 不同的vgg結(jié)構(gòu)闷愤,這樣寫可以有效節(jié)約代碼空間眶痰。
- 2.構(gòu)建模型
class VGG(nn.Module):
#nn.Module是一個(gè)特殊的nn模塊烁峭,加載nn.Module,這是為了繼承父類
def __init__(self, vgg_name):
super(VGG, self).__init__()
#super 加載父類中的__init__()函數(shù)
self.features = self._make_layers(cfg[vgg_name])
self.classifier=nn.Linear(512,10)
#該網(wǎng)絡(luò)輸入為Cifar10數(shù)據(jù)集艾恼,因此輸出為(512,1麸锉,1)
def forward(self, x):
out = self.features(x)
out = out.view(out.size(0), -1)
#這一步將out拉成out.size(0)的一維向量
out = self.classifier(out)
return out
def _make_layers(self, cfg):
layers = []
in_channels = 3
for x in cfg:
if x == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
layers += [nn.Conv2d(in_channels, x, kernel_size=3,
padding=1,bias=False),
nn.BatchNorm2d(x),
nn.ReLU(inplace=True)]
in_channels = x
return nn.Sequential(*layers)
'''
nn.Sequential(*layers) 表示(只是舉例子)
Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace)
(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace)
)
'''
- 3.檢驗(yàn)?zāi)P?/li>
def t():
net = VGG('VGG19')
x = torch.randn(5,3,32,32)
y = net(x)
print(y.size())
if __name__ == "__main__":
t()
#如果輸出為(5,10),表示結(jié)果正確