This article was migrated over from CSDN:
http://blog.csdn.net/mounty_fsc/article/details/51588773
In Solver::ApplyUpdate(), the partial derivatives of the loss with respect to the network weights computed during the backward pass are combined with the configured learning policy to update the network weights, which completes one iteration of learning.
1 Model Optimization
1.1 Loss Function
The loss function $L(W)$ is the empirical loss plus a regularization term, as shown below, where $X^{(i)}$ is an input sample; $f_W$ is the loss on a single sample; $N$ is the number of samples in a mini-batch; and $r(W)$ is the regularization term, weighted by $\lambda$.
$L(W) \approx \frac{1}{N} \sum_i^N f_W\left(X^{(i)}\right) + \lambda r(W)$
In caffe, this breaks down into three phases:
- The forward pass, which computes $f_W$
- The backward pass, which computes $\nabla f_W$
- The weight-update phase, which computes $\Delta W$ from $\nabla f_W$, $\nabla r(W)$, etc., and uses it to update $W$
1.2 Stochastic Gradient Descent
In lenet, the solver type is SGD (stochastic gradient descent).
SGD updates the weights with the following formulas:
$W_{t+1} = W_t + V_{t+1}$
$V_{t+1} = \mu V_t - \alpha \nabla L(W_t)$
where $W_{t+1}$ is the weight at iteration $t+1$; $V_{t+1}$ is the update at iteration $t+1$ (which can also be written as $\Delta W_{t+1}$); $\mu$ is the weight given to the previous update, i.e. the momentum; $\alpha$ is the learning rate; and $\nabla L(W_t)$ is the gradient of the loss with respect to the weights.
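As a quick sanity check of these two formulas, the standalone sketch below runs a few momentum-SGD steps on a single scalar weight. It is a toy, not Caffe code; the quadratic loss $L(w) = \frac{1}{2}w^2$ and the `mu`/`alpha` values are illustrative assumptions.

```cpp
#include <cstdio>

int main() {
  // Toy loss L(w) = 0.5 * w^2, so dL/dw = w.
  double w = 1.0;             // W_t
  double v = 0.0;             // V_t (previous update)
  const double mu = 0.9;      // momentum
  const double alpha = 0.01;  // learning rate

  for (int t = 0; t < 5; ++t) {
    double grad = w;            // \nabla L(W_t)
    v = mu * v - alpha * grad;  // V_{t+1} = mu * V_t - alpha * \nabla L(W_t)
    w = w + v;                  // W_{t+1} = W_t + V_{t+1}
    std::printf("iter %d: w = %f\n", t + 1, w);
  }
  return 0;
}
```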
2 Code Analysis
2.1 ApplyUpdate
```cpp
template <typename Dtype>
void SGDSolver<Dtype>::ApplyUpdate() {
  // Get the learning rate for this iteration
  Dtype rate = GetLearningRate();
  // Update the weights of every layer that has parameters.
  // In lenet, only the four layers `conv1`, `conv2`, `ip1`, `ip2` have parameters,
  // and each of them has a weight blob and a bias blob,
  // so the size of `learnable_params_` is 8.
  for (int param_id = 0; param_id < this->net_->learnable_params().size();
       ++param_id) {
    // Normalization; not needed when iter_size is 1, so lenet skips it
    Normalize(param_id);
    // Regularization
    Regularize(param_id);
    // Compute the update value \delta w
    ComputeUpdateValue(param_id, rate);
  }
  // Apply the accumulated updates to the weights
  this->net_->Update();
}
```
Notes:
- The learning-related solver settings for lenet can be found in lenet_solver.prototxt:

  ```
  # The base learning rate, momentum and the weight decay of the network.
  base_lr: 0.01
  momentum: 0.9
  weight_decay: 0.0005
  # The learning rate policy
  lr_policy: "inv"
  gamma: 0.0001
  power: 0.75
  ```
- The code of the learning-rate function GetLearningRate() is not shown here; its comments (and caffe.proto) list the available learning-rate policies, reproduced below. Lenet uses the inv policy, under which the learning rate changes on every iteration.

  ```cpp
  // The learning rate decay policy. The currently implemented learning rate
  // policies are as follows:
  //    - fixed: always return base_lr.
  //    - step: return base_lr * gamma ^ (floor(iter / step))
  //    - exp: return base_lr * gamma ^ iter
  //    - inv: return base_lr * (1 + gamma * iter) ^ (- power)
  //    - multistep: similar to step but it allows non uniform steps defined by
  //      stepvalue
  //    - poly: the effective learning rate follows a polynomial decay, to be
  //      zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
  //    - sigmoid: the effective learning rate follows a sigmod decay
  //      return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
  //
  // where base_lr, max_iter, gamma, step, stepvalue and power are defined
  // in the solver parameter protocol buffer, and iter is the current iteration.
  ```
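For the inv policy with lenet's settings (base_lr = 0.01, gamma = 0.0001, power = 0.75), the per-iteration rate can be reproduced with a few lines of standalone C++. This is only a sketch of the formula above, not the GetLearningRate() source:

```cpp
#include <cmath>
#include <cstdio>

// inv policy: rate = base_lr * (1 + gamma * iter) ^ (-power)
double inv_rate(double base_lr, double gamma, double power, int iter) {
  return base_lr * std::pow(1.0 + gamma * iter, -power);
}

int main() {
  // Values taken from lenet_solver.prototxt
  const double base_lr = 0.01, gamma = 0.0001, power = 0.75;
  for (int iter = 0; iter <= 10000; iter += 2500) {
    std::printf("iter %5d: rate = %f\n", iter, inv_rate(base_lr, gamma, power, iter));
  }
  return 0;
}
```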
2.2 Regularize
This function effectively performs the following update:
$\nabla w_{ij} = \text{decay} \cdot w_{ij} + \nabla w_{ij}$
The code is as follows:
```cpp
template <typename Dtype>
void SGDSolver<Dtype>::Regularize(int param_id) {
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  const vector<float>& net_params_weight_decay =
      this->net_->params_weight_decay();
  Dtype weight_decay = this->param_.weight_decay();
  string regularization_type = this->param_.regularization_type();
  // local_decay = 0.0005 in lenet
  Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
  ...
  if (regularization_type == "L2") {
    // axpy means ax_plus_y, i.e., y = a*x + y
    caffe_axpy(net_params[param_id]->count(),
               local_decay,
               net_params[param_id]->cpu_data(),
               net_params[param_id]->mutable_cpu_diff());
  }
  ...
}
```
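Element by element, that axpy call simply folds the gradient of the L2 penalty into the existing parameter diff. Below is a minimal standalone sketch of the same operation, using plain arrays instead of Blobs and made-up values:

```cpp
#include <cstdio>

int main() {
  const int count = 3;
  const double local_decay = 0.0005;       // weight_decay * per-param decay multiplier
  double data[count] = {0.2, -0.5, 1.0};   // parameter values w_ij (cpu_data)
  double diff[count] = {0.1, 0.3, -0.2};   // gradients dL/dw_ij (cpu_diff)

  // Equivalent of caffe_axpy(count, local_decay, data, diff):
  //   diff = local_decay * data + diff
  for (int i = 0; i < count; ++i) {
    diff[i] += local_decay * data[i];
    std::printf("diff[%d] = %f\n", i, diff[i]);
  }
  return 0;
}
```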
2.3 ComputeUpdateValue
This function effectively performs the following update:
$\nabla w_{ij} = \text{lr\_rate} \cdot \nabla w_{ij} + \text{momentum} \cdot \nabla w'_{ij}$
where $\nabla w'_{ij}$ is the update value from the previous iteration, kept in `history_`. Note that the result is stored back into `cpu_diff`, i.e. into the gradient of the loss with respect to the parameters.
The code is as follows:
```cpp
template <typename Dtype>
void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  const vector<float>& net_params_lr = this->net_->params_lr();
  // momentum = 0.9 in lenet
  Dtype momentum = this->param_.momentum();
  // local_rate = lr_mult * global_rate
  // lr_mult is the per-layer learning-rate multiplier, set in lenet_train_test.prototxt
  Dtype local_rate = rate * net_params_lr[param_id];
  // Compute the update to history, then copy it to the parameter diff.
  ...
  // axpby means ax_plus_by, i.e., y = a*x + b*y
  // Compute the new update value \delta w and store it in the update history
  caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
                  net_params[param_id]->cpu_diff(), momentum,
                  history_[param_id]->mutable_cpu_data());
  // Copy the update value \delta w from the history into the parameter diff
  caffe_copy(net_params[param_id]->count(),
             history_[param_id]->cpu_data(),
             net_params[param_id]->mutable_cpu_diff());
  ...
}
```
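Staging the final update value in the parameter diff, instead of applying it here, is what keeps the last step solver-agnostic: by the time `net_->Update()` runs, every learnable blob's diff already holds its $\Delta w$, and the net can apply one and the same `data = data - diff` operation to all of them regardless of which solver produced the diff.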
2.4 net_->Update
This effectively performs the following update:
$w_{ij} = w_{ij} + (-1) \cdot \nabla w_{ij}$
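Chaining sections 2.2 through 2.4 for a single scalar parameter reproduces the momentum-SGD rule from section 1.2, with the sign folded into the final subtraction. The standalone sketch below uses plain doubles instead of Blobs and made-up values:

```cpp
#include <cstdio>

int main() {
  // One scalar parameter and its state carried across iterations.
  double data = 0.5;      // w
  double diff = 0.2;      // dL/dw produced by the backward pass
  double history = 0.05;  // previous update value V_t

  const double local_decay = 0.0005;  // weight decay
  const double local_rate = 0.01;     // learning rate
  const double momentum = 0.9;

  // 2.2 Regularize: diff = local_decay * data + diff
  diff += local_decay * data;
  // 2.3 ComputeUpdateValue: history = local_rate * diff + momentum * history,
  //     then the update value is copied back into diff
  history = local_rate * diff + momentum * history;
  diff = history;
  // 2.4 net_->Update: data = data + (-1) * diff
  data -= diff;

  std::printf("updated w = %f\n", data);
  return 0;
}
```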