訓(xùn)練
- 在訓(xùn)練之前犀填,可以先定一個(gè)train_one_epoch()函數(shù)用于進(jìn)行一個(gè)epoch的訓(xùn)練蠢壹。這個(gè)函數(shù)包括使用train_loader中的每一個(gè)batch進(jìn)行訓(xùn)練的訓(xùn)練部分;
def train_one_epoch(model, train_loader, criterion, optimizer, device):
model.train() # 切換模型到訓(xùn)練模式
running_loss = 0.0
for inputs, labels in train_loader:
inputs, labels = inputs.to(device), labels.to(device)
# 前向傳播
outputs = model(inputs)
loss = criterion(outputs, labels)
# 后向傳播
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 計(jì)算batch內(nèi)損失
running_loss += loss.item() * inputs.size(0)
# 計(jì)算epoch內(nèi)損失
epoch_loss = running_loss / len(train_loader.dataset)
return epoch_loss
enumerate()
注:有時(shí)候在dataloader外面經(jīng)常會(huì)套一個(gè)enumerate()函數(shù)九巡,enumerate()函數(shù)用于在遍歷可迭代對(duì)象時(shí)图贸,同時(shí)獲得元素的索引和值。它的使用并不是強(qiáng)制性的,取決于是否需要跟蹤當(dāng)前批次的索引求妹。如果不需要索引乏盐,僅僅需要遍歷數(shù)據(jù)佳窑,那么可以直接迭代DataLoader而不使用enumerate()
舉例:
for batch_idx, batch_data in enumerate(train_loader):
# 將數(shù)據(jù)移動(dòng)到GPU
inputs, labels = batch_data
inputs, labels = inputs.to(device), labels.to(device)
# 前向傳播
outputs = model(inputs)
loss = criterion(outputs, labels)
# 后向傳播
optimizer.zero_grad() # 清零所有參數(shù)的梯度
loss.backward() # 計(jì)算梯度
optimizer.step() # 更新參數(shù)
# 使用batch_idx
if batch_idx % 10 == 0: # 每10個(gè)批次打印一次損失
print(f'Batch [{batch_idx}], Loss: {loss.item():.4f}')
驗(yàn)證
- 如果有驗(yàn)證集制恍,可以編寫validate_one_epoch()函數(shù)用于實(shí)現(xiàn)對(duì)驗(yàn)證集中的每個(gè)批次進(jìn)行驗(yàn)證的驗(yàn)證部分
# 定義驗(yàn)證函數(shù)
def validate_one_epoch(model, valid_loader, criterion, device):
model.eval() # 切換到評(píng)估模式
running_loss = 0.0
# 在驗(yàn)證過(guò)程中不需要計(jì)算梯度
with torch.no_grad():
for inputs, labels in valid_loader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
loss = criterion(outputs, labels)
running_loss += loss.item() * inputs.size(0) # 計(jì)算平均損失
epoch_loss = running_loss / len(valid_loader.dataset)
return epoch_loss
在每個(gè)epoch中進(jìn)行訓(xùn)練+驗(yàn)證
num_epochs = 10
for epoch in range(num_epochs):
train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
valid_loss = validate_one_epoch(model, valid_loader, criterion, device)
print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {train_loss:.4f}, Validation Loss: {valid_loss:.4f}')
測(cè)試(推理)
使用訓(xùn)練好的模型進(jìn)行推理,其實(shí)validation部分就是推理神凑,因此代碼和validate_one_epoch比較類似
# 設(shè)置模型為評(píng)估模式
model.eval()
# 進(jìn)行推理
with torch.no_grad(): # 在推理過(guò)程中不需要計(jì)算梯度
outputs = model(new_inputs)
# 輸出結(jié)果
print(outputs)