在基于PyTorch實(shí)現(xiàn)一個(gè)算法時(shí)赃蛛,通常會(huì)將代碼分成多個(gè)模塊忧勿,每個(gè)模塊單獨(dú)放在一個(gè)Python腳本中沈跨。這種做法可以提高代碼的可讀性秽五、可維護(hù)性和重用性。包括:
模型定義腳本 (model.py)
- 包含對(duì)神經(jīng)網(wǎng)絡(luò)模型類(lèi)的定義菊卷,如果模型比較復(fù)雜煤率,可以先定義每個(gè)小的layer/block嗜闻,然后類(lèi)套類(lèi)
數(shù)據(jù)處理腳本 (data.py)
- 數(shù)據(jù)加載和預(yù)處理的相關(guān)代碼秘车,包含定義dataset類(lèi)典勇,生成各種data_loader等等
訓(xùn)練和驗(yàn)證腳本 (train.py)
- 訓(xùn)練和驗(yàn)證的相關(guān)代碼,比如train_one_epoch()叮趴,validate_one_epoch()等
推理腳本 (inference.py)
- 包含使用訓(xùn)練好的模型進(jìn)行推理的代碼割笙,例如:
import torch
def infer(model, inputs, device):
model.eval()
inputs = inputs.to(device)
with torch.no_grad():
outputs = model(inputs)
return outputs
工具文件 (utils.py)
這個(gè)文件可以包含一些輔助函數(shù)。例如保存和加載模型:
import torch
# 保存模型
def save_model(model, path='model.pth'):
torch.save(model.state_dict(), path)
# 加載模型
def load_model(model, path='model.pth'):
model.load_state_dict(torch.load(path))
return model
配置文件 (config.py)
- 這個(gè)文件可以包含一些配置參數(shù)眯亦。例如:
batch_size = 32
learning_rate = 0.01
num_epochs = 10
主程序腳本 (main.py 或 run.py)
- 負(fù)責(zé)調(diào)用其他模塊伤溉,進(jìn)行訓(xùn)練、驗(yàn)證和推理妻率。例如:
# 導(dǎo)入其他庫(kù)
import torch
import torch.nn as nn
import torch.optim as optim
# 導(dǎo)入自己寫(xiě)的文件的庫(kù)
from config import batch_size, learning_rate, num_epochs, model_save_path
from data import get_data_loaders
from model import SimpleModel
from train import train_one_epoch, validate
from inference import infer
from utils import save_model, load_model
def main():
# 指定設(shè)備
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 實(shí)例化模型
model = SimpleModel().to(device)
# 定義損失函數(shù)
criterion = nn.MSELoss()
# 定義優(yōu)化器
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
# 定義數(shù)據(jù)集
train_loader, valid_loader = get_data_loaders(batch_size)
# 訓(xùn)練模型
for epoch in range(num_epochs):
train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
valid_loss = validate(model, valid_loader, criterion, device)
print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {train_loss:.4f}, Validation Loss: {valid_loss:.4f}')
# 保存模型
save_model(model, model_save_path)
# 進(jìn)行推理
model = load_model(SimpleModel().to(device), model_save_path, device)
new_inputs = torch.randn(10, 10)
outputs = infer(model, new_inputs, device)
print("Inference results:")
print(outputs)
if __name__ == "__main__":
main()