Defining a custom deep learning model by subclassing nn.Module is standard practice in PyTorch. nn.Module is the base class for all neural network modules, and it provides many useful methods and attributes (parameter registration, state_dict serialization, moving the model between devices with .to(), and so on). Defining a custom model involves the following steps:
- Define a model class that inherits from nn.Module
- Define the model's layers in the __init__ method
- Define the forward-pass logic in the forward method
- Instantiate the model and use it (a minimal sketch of the whole pattern follows this list)
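Before the full examples, here is a minimal sketch of the four steps above; the TinyModel name, the single Linear layer, and the input sizes are placeholder assumptions chosen purely for illustration:

import torch
import torch.nn as nn

class TinyModel(nn.Module):              # step 1: subclass nn.Module
    def __init__(self):
        super().__init__()               # let nn.Module set up parameter tracking
        self.layer = nn.Linear(4, 2)     # step 2: declare layers in __init__

    def forward(self, x):
        return self.layer(x)             # step 3: define the forward computation

model = TinyModel()                      # step 4: instantiate...
y = model(torch.randn(3, 4))             # ...and call it; nn.Module.__call__ dispatches to forward
print(y.shape)                           # torch.Size([3, 2])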
Example code
- MLP
import torch
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Declare layers as attributes so nn.Module registers their parameters
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)  # raw logits; pair with nn.CrossEntropyLoss
        return x
# Example usage
input_size = 784   # e.g. a flattened 28x28 image
hidden_size = 128
output_size = 10   # 10-class classification
model = MLP(input_size, hidden_size, output_size)
print(model)
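A quick smoke test confirms the wiring before any real training; the batch size of 64, the SGD learning rate, and the random targets below are illustrative assumptions, not from the original:

# Forward pass on a random batch (batch size 64 is an illustrative choice)
x = torch.randn(64, input_size)
logits = model(x)
print(logits.shape)  # torch.Size([64, 10])

# One illustrative training step with cross-entropy loss
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
targets = torch.randint(0, output_size, (64,))  # fake labels for demonstration
loss = criterion(logits, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()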
- CNN
import torch
import torch.nn as nn
import torch.nn.functional as F
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        # Assumes single-channel 28x28 input (e.g. MNIST)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        # Two 2x2 poolings: 28x28 -> 14x14 -> 7x7, with 64 channels at the end
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)  # flatten to [batch, 64 * 7 * 7]
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
# Example usage
model = CNN()
print(model)
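The same kind of shape check verifies the 64 * 7 * 7 flattening arithmetic; the MNIST-shaped batch of 8 below is an illustrative assumption:

# Forward pass with a single-channel 28x28 batch (batch size 8 is illustrative)
images = torch.randn(8, 1, 28, 28)  # [batch, channels, height, width]
logits = model(images)
print(logits.shape)                 # torch.Size([8, 10])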
- Transformer
import torch
import torch.nn as nn
class TransformerModel(nn.Module):
    def __init__(self, input_dim, model_dim, num_heads, num_layers, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, model_dim)
        # batch_first=True makes the encoder accept [batch_size, seq_len, model_dim],
        # so no permute is needed (the default layout is [seq_len, batch_size, model_dim])
        encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=num_heads, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(model_dim, output_dim)

    def forward(self, src):
        # src: [batch_size, seq_len] of token indices
        # (positional encoding is omitted here for brevity)
        src = self.embedding(src)               # [batch_size, seq_len, model_dim]
        output = self.transformer_encoder(src)  # [batch_size, seq_len, model_dim]
        output = self.fc(output.mean(dim=1))    # mean-pool over seq_len -> [batch_size, output_dim]
        return output
# Example usage
input_dim = 10000  # vocabulary size
model_dim = 512
num_heads = 8
num_layers = 6
output_dim = 10    # 10-class classification
model = TransformerModel(input_dim, model_dim, num_heads, num_layers, output_dim)
print(model)
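A shape check with random token ids exercises the whole stack; the sequence length of 20 and batch size of 4 are illustrative assumptions:

# Forward pass with random token indices (seq_len=20, batch=4 are illustrative)
tokens = torch.randint(0, input_dim, (4, 20))  # [batch_size, seq_len]
out = model(tokens)
print(out.shape)                               # torch.Size([4, 10])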