3.12 權(quán)重衰減
上一節(jié)中我們觀察了過擬合現(xiàn)象袋哼,即模型的訓(xùn)練誤差遠(yuǎn)小于它在測試集上的誤差殿漠。雖然增大訓(xùn)練數(shù)據(jù)集可能會減輕過擬合笨触,但是獲取額外的訓(xùn)練數(shù)據(jù)往往代價高昂超升。本節(jié)介紹應(yīng)對過擬合問題的常用方法:權(quán)重衰減(weight decay)拧略。
3.12.1 方法
3.12.2 高維線性回歸實(shí)驗(yàn)
%matplotlib inline
import torch
import torch.nn as nn
import numpy as np
import sys
sys.path.append("..")
import d2lzh_pytorch as d2l
n_train, n_test, num_inputs = 20, 100, 200
true_w, true_b = torch.ones(num_inputs, 1) * 0.01, 0.05
features = torch.randn((n_train + n_test, num_inputs))
labels = torch.matmul(features, true_w) + true_b
labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float)
train_features, test_features = features[:n_train, :], features[n_train:, :]
train_labels, test_labels = labels[:n_train], labels[n_train:]
3.12.3 從零開始實(shí)現(xiàn)
3.12.3.1 初始化模型參數(shù)
首先逝嚎,定義隨機(jī)初始化模型參數(shù)的函數(shù)挤牛。該函數(shù)為每個參數(shù)都附上梯度。
def init_params():
w = torch.randn((num_inputs, 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
return [w, b]
3.12.3.2 定義L2范數(shù)懲罰項(xiàng)
def l2_penalty(w):
return (w**2).sum() / 2
3.12.3.3 定義訓(xùn)練和測試
下面定義如何在訓(xùn)練數(shù)據(jù)集和測試數(shù)據(jù)集上分別訓(xùn)練和測試模型船庇。與前面幾節(jié)中不同的是,這里在計(jì)算最終的損失函數(shù)時添加了L2范數(shù)懲罰項(xiàng)侣监。
batch_size, num_epochs, lr = 1, 100, 0.003
net, loss = d2l.linreg, d2l.squared_loss
dataset = torch.utils.data.TensorDataset(train_features, train_labels)
train_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True)
def fit_and_plot(lambd):
w, b = init_params()
train_ls, test_ls = [], []
for _ in range(num_epochs):
for X, y in train_iter:
# 添加了L2范數(shù)懲罰項(xiàng)
l = loss(net(X, w, b), y) + lambd * l2_penalty(w)
l = l.sum()
if w.grad is not None:
w.grad.data.zero_()
b.grad.data.zero_()
l.backward()
d2l.sgd([w, b], lr, batch_size)
train_ls.append(loss(net(train_features, w, b), train_labels).mean().item())
test_ls.append(loss(net(test_features, w, b), test_labels).mean().item())
d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',
range(1, num_epochs + 1), test_ls, ['train', 'test'])
print('L2 norm of w:', w.norm().item())
3.12.3.4 觀察過擬合
接下來鸭轮,讓我們訓(xùn)練并測試高維線性回歸模型。當(dāng)lambd
設(shè)為0時橄霉,我們沒有使用權(quán)重衰減窃爷。結(jié)果訓(xùn)練誤差遠(yuǎn)小于測試集上的誤差。這是典型的過擬合現(xiàn)象姓蜂。
fit_and_plot(lambd=0)
輸出:
L2 norm of w: 15.114808082580566
3.12.3.5 使用權(quán)重衰減
下面我們使用權(quán)重衰減按厘。可以看出钱慢,訓(xùn)練誤差雖然有所提高逮京,但測試集上的誤差有所下降。過擬合現(xiàn)象得到一定程度的緩解束莫。另外懒棉,權(quán)重參數(shù)的
L2范數(shù)比不使用權(quán)重衰減時的更小,此時的權(quán)重參數(shù)更接近0览绿。
fit_and_plot(lambd=3)
輸出:
L2 norm of w: 0.035220853984355927
3.12.4 簡潔實(shí)現(xiàn)
這里我們直接在構(gòu)造優(yōu)化器實(shí)例時通過weight_decay
參數(shù)來指定權(quán)重衰減超參數(shù)策严。默認(rèn)下,PyTorch會對權(quán)重和偏差同時衰減饿敲。我們可以分別對權(quán)重和偏差構(gòu)造優(yōu)化器實(shí)例妻导,從而只對權(quán)重衰減。
def fit_and_plot_pytorch(wd):
# 對權(quán)重參數(shù)衰減怀各。權(quán)重名稱一般是以weight結(jié)尾
net = nn.Linear(num_inputs, 1)
nn.init.normal_(net.weight, mean=0, std=1)
nn.init.normal_(net.bias, mean=0, std=1)
optimizer_w = torch.optim.SGD(params=[net.weight], lr=lr, weight_decay=wd) # 對權(quán)重參數(shù)衰減
optimizer_b = torch.optim.SGD(params=[net.bias], lr=lr) # 不對偏差參數(shù)衰減
train_ls, test_ls = [], []
for _ in range(num_epochs):
for X, y in train_iter:
l = loss(net(X), y).mean()
optimizer_w.zero_grad()
optimizer_b.zero_grad()
l.backward()
# 對兩個optimizer實(shí)例分別調(diào)用step函數(shù)倔韭,從而分別更新權(quán)重和偏差
optimizer_w.step()
optimizer_b.step()
train_ls.append(loss(net(train_features), train_labels).mean().item())
test_ls.append(loss(net(test_features), test_labels).mean().item())
d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',
range(1, num_epochs + 1), test_ls, ['train', 'test'])
print('L2 norm of w:', net.weight.data.norm().item())
與從零開始實(shí)現(xiàn)權(quán)重衰減的實(shí)驗(yàn)現(xiàn)象類似,使用權(quán)重衰減可以在一定程度上緩解過擬合問題瓢对。
fit_and_plot_pytorch(0)
輸出:
L2 norm of w: 12.86785888671875
fit_and_plot_pytorch(3)
輸出:
L2 norm of w: 0.09631537646055222
小結(jié)
- 正則化通過為模型損失函數(shù)添加懲罰項(xiàng)使學(xué)出的模型參數(shù)值較小狐肢,是應(yīng)對過擬合的常用手段。
- 權(quán)重衰減等價于L2范數(shù)正則化沥曹,通常會使學(xué)到的權(quán)重參數(shù)的元素較接近0份名。
- 權(quán)重衰減可以通過優(yōu)化器中的weight_decay超參數(shù)來指定。
- 可以定義多個優(yōu)化器實(shí)例對不同的模型參數(shù)使用不同的迭代方法妓美。
注:本節(jié)除了代碼之外與原書基本相同僵腺,原書傳送門