For the complete example code and the follow-up content, visit: https://www.emperinter.info/2020/08/01/learning-rate-in-pytorch/
Motivation
After trying the official example code, I wanted to improve training accuracy, and adjusting the learning rate seemed like the obvious lever: a fixed learning rate is rarely ideal for a whole run. I tried several ways of updating it, but to my surprise the results actually got worse! Probably I just didn't explore enough learning-rate settings.
Update methods
Modify the lr parameter inside the optimizer directly
- Define a simple neural network model: y = Wx + b
import torch
import matplotlib.pyplot as plt
%matplotlib inline
from torch.optim import *
from torch.optim import lr_scheduler   # used by the scheduler examples below
import torch.nn as nn

# A single linear layer, i.e. y = Wx + b
class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()
        self.fc = nn.Linear(1, 10)

    def forward(self, x):
        return self.fc(x)
- Modify the value of lr directly
model = net()
LR = 0.01
optimizer = Adam(model.parameters(), lr=LR)
lr_list = []
for epoch in range(100):
    if epoch % 5 == 0:
        for p in optimizer.param_groups:
            p['lr'] *= 0.9
    lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
plt.plot(range(100), lr_list, color='r')
The key is these two lines, which give a manual step-wise decay; you can swap in any other transform of the lr you need (see the small sketch after this snippet):
for p in optimizer.param_groups:
    p['lr'] *= 0.9
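For example, a quick sketch of a different transform written the same way; the 0.95 decay rate here is just an illustrative pick, not from the original post:
# Continuing from the snippet above: set the lr from an explicit function of the epoch
lr_list = []
for epoch in range(100):
    for p in optimizer.param_groups:
        p['lr'] = LR * (0.95 ** epoch)   # exponential decay; substitute any formula you like
    lr_list.append(optimizer.param_groups[0]['lr'])
plt.plot(range(100), lr_list, color='r')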
Using the scheduler classes provided by torch.optim.lr_scheduler
- torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)
Parameter | Meaning |
---|---|
lr_lambda | Receives an int argument (the epoch) and returns a factor; the lr is set to the initial lr multiplied by that factor. If a list of lambda functions is passed, each one is applied to the corresponding param_group of the optimizer. |
import numpy as np
lr_list = []
model = net()
LR = 0.01
optimizer = Adam(model.parameters(), lr=LR)
# sin(x)/x shape, only meant to visualize the schedule; guard epoch 0 to avoid dividing by zero
lambda1 = lambda epoch: np.sin(epoch) / epoch if epoch > 0 else 1.0
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
for epoch in range(100):
    scheduler.step()
    lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
plt.plot(range(100), lr_list, color='r')
- torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min=0, last_epoch=-1)
Parameter | Meaning |
---|---|
T_max | Number of epochs corresponding to half a cosine period |
eta_min | Minimum lr, defaults to 0 |
lr_list = []
model = net()
LR = 0.01
optimizer = Adam(model.parameters(), lr=LR)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
for epoch in range(100):
    scheduler.step()
    lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
plt.plot(range(100), lr_list, color='r')
- torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)
Reduces the learning rate once the loss stops decreasing (or the accuracy stops improving). The parameters mean the following (a short usage sketch follows the table):
Parameter | Meaning |
---|---|
mode | 'min' checks whether the metric has stopped decreasing, 'max' checks whether it has stopped increasing |
factor | When triggered, lr *= factor |
patience | Number of epochs without improvement to tolerate before reducing the lr |
verbose | Print a message whenever the lr is reduced |
threshold | Only changes larger than this threshold count as significant |
threshold_mode | Either 'rel' or 'abs'. 'rel': in max mode an improvement beyond best*(1+threshold) is significant, in min mode a drop below best*(1-threshold); 'abs': in max mode beyond best+threshold, in min mode below best-threshold |
cooldown | Number of epochs to wait after a reduction before monitoring resumes, so the lr doesn't drop too quickly |
min_lr | Lower bound on the lr |
eps | If the difference between the new and old lr is smaller than eps, the update is ignored |
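Since the post gives no snippet for this scheduler, here is a minimal hedged sketch, reusing the toy model from above with a made-up metric that plateaus just to make the lr drop visible (the factor and patience values are illustrative choices):
model = net()
LR = 0.01
optimizer = Adam(model.parameters(), lr=LR)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
lr_list = []
for epoch in range(100):
    val_loss = max(0.5, 1.0 - 0.05 * epoch)   # fake metric: improves for 10 epochs, then plateaus
    scheduler.step(val_loss)                  # ReduceLROnPlateau needs the monitored value passed to step()
    lr_list.append(optimizer.param_groups[0]['lr'])
plt.plot(range(100), lr_list, color='r')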
For other ways of updating the learning rate, visit: https://www.emperinter.info/2020/08/01/learning-rate-in-pytorch/
Example
Update method used
The example code offers three options: cosine annealing (the default; the other two are commented out), an e^-x style exponential decay, and reducing the lr when the loss stops decreasing. A rough sketch of these three choices follows below; try the rest yourself!
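As a hedged sketch of what those three choices might look like where the scheduler is created (the variable name model and the exponential rate are assumptions on my part, not taken from the full code behind the link):
model = net()                                   # stand-in for whatever network the full example trains
optimizer = Adam(model.parameters(), lr=0.01)

# Option 1 (the default in the post): cosine annealing
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)

# Option 2: e^-x style decay via LambdaLR (commented out in the post)
# scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: np.exp(-epoch))

# Option 3: reduce the lr when the monitored loss stops decreasing (commented out in the post)
# scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)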
Full code
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from torch.optim import *
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])   # normalize CIFAR-10 images to [-1, 1]
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=0)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=0)
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# For the complete example code and what follows, visit: https://www.emperinter.info/2020/08/01/learning-rate-in-pytorch/
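The rest of the example (the network, loss, training loop, and TensorBoard logging) lives behind the link above. Purely as a hedged sketch of how those pieces typically fit together with the cosine option and the imports already shown, and assuming the standard CIFAR-10 tutorial CNN rather than the author's exact model:
from torch.optim import lr_scheduler   # explicit import, in case the star import above doesn't expose it

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)   # the default choice described above
writer = SummaryWriter('runs/lr_' + datetime.now().strftime('%Y%m%d_%H%M%S'))

for epoch in range(2):                     # a couple of epochs, just to show the flow
    running_loss = 0.0
    for inputs, labels in trainloader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    scheduler.step()                       # step the scheduler once per epoch, after optimizer.step()
    writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
    writer.add_scalar('train_loss', running_loss / len(trainloader), epoch)
writer.close()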