如何確定全連接的參數(shù)
雖然目前使用全連接層的網(wǎng)絡(luò)模型越來(lái)越少翼虫,但是仍有部分網(wǎng)絡(luò)需要全連接層炼邀,但是如果通過(guò)CNN計(jì)算圖片的輸出尺寸可以說(shuō)有點(diǎn)復(fù)雜。現(xiàn)在就使用PyTorch自帶的功能來(lái)實(shí)現(xiàn)這個(gè)計(jì)算江滨,可以說(shuō)非常簡(jiǎn)單油坝。首先,我們先定義如下的網(wǎng)絡(luò):
class LinearDemo(nn.Module):
def __init__(self):
super(LinearDemo,self).__init__()
self.conv=nn.Sequential(
nn.Conv2d(3,96,kernel_size=11,stride=4),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3,stride=2),
nn.Conv2d(96,256,kernel_size=5,padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3,stride=2),
nn.Conv2d(256,384,kernel_size=3,padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(384,384,kernel_size=3,padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(384,256,kernel_size=3,padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3,stride=2)
)
上面代碼中的基本組件這里就不多贅述了蹬挺,下面正常書(shū)寫(xiě)全連接層如下:
self.fc=nn.Sequential(
# nn.Linear(???,4096)
# )
其中???就是我們需要計(jì)算的參數(shù)值维贺,如果通過(guò)層的關(guān)系進(jìn)行計(jì)算則很容易出錯(cuò)。這里推薦使用PyTorch自帶的forward方法進(jìn)行推算巴帮。我們寫(xiě)forward方法如下:
def forward(self,x):
x=self.conv(x)
print(x.size())
這里我們可以在main方法中進(jìn)行調(diào)用后溯泣,就可以輸出該參數(shù)虐秋。main方法如下:
net=LinearDemo()
data_input=torch.randn(1,3,80,280)
print(data_input.size())
net(data_input)
這樣就將上面的參數(shù)輸出了。非常的簡(jiǎn)單