權(quán)重衰退
-
使用均方范數(shù)作為硬性限制
如下所示漫玄,其中l(wèi)是我們要優(yōu)化的損失函數(shù),w和b是兩個參數(shù)压彭。w為權(quán)重睦优,b為偏移,但這個優(yōu)化函數(shù)并不常用壮不,多使用下面的柔性限制汗盘。
-
使用均方范數(shù)作為柔性限制
以上可以通過拉格朗日乘子來證明,超參數(shù)控制了正則項的重要程度询一。
其中隐孽,表示的最優(yōu)解癌椿。
- 無作用,即不會影響w的取值菱阵,等價于之前的
- ,等價于之前的踢俄,所以,那么最優(yōu)解
- 如果我們想讓模型復(fù)雜度低一些,那我們將增加些以滿足要求送粱。
-
參數(shù)更新法則
如下所示,將畫黃線部分代入紅色表達式并進行化簡掂之,即可得到時間t更新參數(shù)對應(yīng)的表達式抗俄。
其中表示學(xué)習(xí)率,在上圖中的第二個公式世舰,后面部分(減去學(xué)習(xí)率*梯度)與我們之前講的梯度下降是一樣的动雹,只是我們現(xiàn)在在每次更新前,在前面那里多減了一個跟压,進行權(quán)重的縮小胰蝠。 -
總結(jié)
5.代碼實現(xiàn)
# 權(quán)重衰減是最廣泛使用的正則化的技術(shù)之一 %matplotlib inline import math import torch from torch import nn from d2l import torch as d2l # 1. 生成一些數(shù)據(jù) n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5 # 隨機生成權(quán)重,以及將偏差設(shè)成為0.05 true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05 # 使用synthetic_data生成人工數(shù)據(jù)集震蒋,以及使用load_array加載內(nèi)存數(shù)據(jù) train_data = d2l.synthetic_data(true_w, true_b, n_train) train_iter = d2l.load_array(train_data, batch_size) test_data = d2l.synthetic_data(true_w, true_b, n_test) test_iter = d2l.load_array(test_data, batch_size, is_train=False) # 2.初始化模型參數(shù) def init_params(): # 根據(jù)圖片要求進行生成 w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True) b = torch.zeros(1, requires_grad=True) return [w, b] # 3. 定義L2范數(shù)懲罰(對照公式),也是本次的核心茸塞,注意我們在該函數(shù)中沒有將lambda放在里面 def L2_penalty(w): return torch.sum(w.pow(2)) / 2 # 拓展:我們也可以用L1 penalty(w) def L1_penalty(w): return torch.sum(torch.abs(w)) # 4. 定義訓(xùn)練代碼實現(xiàn) # lambda為超級參數(shù) def train(lambd): w, b = init_params() net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss num_epochs, lr = 100, 0.03 animator = d2l.Animator(xlabel='epochs', ylabel='loss',yscale='log', xlim=[5, num_epochs], legend=['train', 'test']) for epoch in range(num_epochs): for X, y in train_iter: # with torch.enable_grad(): # 增加L2范數(shù)懲罰項,廣播機制使l2_penalty(w)成為一個長度為`batch_size`的向量查剖。 # 以下表達式對應(yīng)柔性限制的核心 l = loss(net(X), y) + lambd * L1_penalty(w) l.sum().backward() d2l.sgd([w, b], lr, batch_size) if(epoch + 1) % 5 ==0: animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss), d2l.evaluate_loss(net, test_iter, loss))) print('w的L2范數(shù)是:', torch.norm(w).item()) # 忽略正則化直接訓(xùn)練 train(lambd=0) # 出現(xiàn)嚴(yán)重的過擬合 # 嘗試改變lambda的值 train(lambd=3)
-
對應(yīng)運行結(jié)果
當(dāng)我們將設(shè)為0時钾虐,得到的結(jié)果,如第1張圖所示笋庄,很明顯發(fā)生了嚴(yán)重的過濾盒效扫,當(dāng)我們將設(shè)為3時,得到的效果還不錯直砂,具體哪個參數(shù)最優(yōu)菌仁,則需要自己去調(diào)參。