作者:geekboys
日期:2020-3-4
PyTorch模型定義的三要素
1.必須繼承nn.Module這個類躁绸,要讓PyTorch知道這個類是一個Module
2.在init(self)中設(shè)置好需要的"組件"(如conv,pooling,Linear,BatchNorm等)
3.最后,在forward(self,x)中定義好的“組件”進(jìn)行組裝猴仑,就像搭積木囚痴,把網(wǎng)絡(luò)結(jié)構(gòu)搭建出來膜廊,這樣一個模型就定義好了泳叠。
這里可以搭建一個簡單的模型來體現(xiàn)一下這種模型搭建的方法:
#一個簡單的模型
import torch
import torch.nn as nn
import torch.functional as F
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()#實(shí)現(xiàn)父類的初始化
self.conv1=nn.Conv2d(3,6,5)#定義卷積層組件
self.pool1=nn.MaxPool2d(2,2)#定義池化層組件
self.conv2=nn.Conv2dn(6,16,5)
self.pool2=nn.MaxPool2d(2,2)
self.fc1=nn.Linear(16*5*5,120)#定義線性連接
self.fc2=nn.Linear(120,84)
self.fc3=nn.Linear(84,10)
當(dāng)這些組件定義好之后迂卢,就可以定義forward()函數(shù)亮隙,用來搭建模型結(jié)構(gòu)途凫。
def forward(self,x):#x模型的輸入
x=self.pool1(F.relu(self.conv1(x)))
x=self.pool2(F.relu(self.conv2(x)))
x=x.view(-1,16*5*5)#表示將x進(jìn)行reshape,為后面做為全連接層的輸入
x=F.relu(self.fc1(x))
x=F.relu(self.fc2(x))
x=self.fc3(x)
return x
上面我們就成功的搭建了一個網(wǎng)絡(luò)是不是很方便溢吻,當(dāng)我們實(shí)例化一個模型net=Net()维费,然后把輸入inputs扔進(jìn)去,outputs=net(inputs)就可以得到輸出outputs.
在PyTorch模型定義中還會經(jīng)常的使用Sequetial這個組件
nn.Sequetial
torch.nn.Sequential其實(shí)就是Sequential容器促王,該容器將一系列操作按先后順序給包起來犀盟,方便重復(fù)使用。
所以總結(jié)起來蝇狼,PyTorch模型的定義過程為:
模型的定義就是先繼承阅畴,在構(gòu)建組件,最后組裝