本文從CSDN上轉(zhuǎn)移過來:
http://blog.csdn.net/mounty_fsc/article/details/51090114
在訓(xùn)練lenet的train_lenet.sh
中內(nèi)容為:
./build/tools/caffe train --solver=examples/mnist/lenet_solver.prototxt
由此可知,訓(xùn)練網(wǎng)咯模型是由tools/caffe.cpp
生成的工具caffe
在模式train
下完成的淌友。
初始化過程總的來說,從main()
、train()
中創(chuàng)建Solver
跺讯,在Solver
中創(chuàng)建Net
,在Net
中創(chuàng)建Layer
.
1 程序入口
- 找到
caffe.cpp
的main
函數(shù)中殉农,通過GetBrewFunction(caffe::string(argv[1]))()
調(diào)用執(zhí)行train()
函數(shù)刀脏。 -
train中
,通過參數(shù)-examples/mnist/lenet_solver.prototxt
把solver
參數(shù)讀入solver_param
中。 - 隨后注冊并定義
solver
的指針(見第2節(jié))shared_ptr<caffe::Solver<float> >
solver(caffe::SolverRegistry<float>::CreateSolver(solver_param))
```
- 調(diào)用
solver
的Solver()
方法超凳。多個(gè)GPU涉及到GPU間帶異步處理問題(見第3節(jié))if (gpus.size() > 1) { caffe::P2PSync<float> sync(solver, NULL, solver->param()); sync.run(gpus); } else { LOG(INFO) << "Starting Optimization"; solver->Solve(); }
2 Solver的創(chuàng)建
在1中愈污,Solver
的指針solver
是通過SolverRegistry::CreateSolver
創(chuàng)建的,CreateSolver
函數(shù)中值得注意帶是return registry[type](param)
// Get a solver using a SolverParameter.
static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
const string& type = param.type();
CreatorRegistry& registry = Registry();
CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
<< " (known types: " << SolverTypeListString() << ")";
return registry[type](param);
}
其中:
registry
是一個(gè)map<string,Creator>: typedef std::map<string, Creator> CreatorRegistry
其中Creator
是一個(gè)函數(shù)指針類型: typedef Solver<Dtype>* (*Creator)(const SolverParameter&)
registry[type]
為一個(gè)函數(shù)指針變量轮傍,在Lenet5
中暂雹,此處具體的值為 caffe::Creator_SGDSolver<float>(caffe::SolverParameter const&)
其中Creator_SGDSolver
在以下宏中定義,
REGISTER_SOLVER_CLASS(SGD)
該宏完全展開得到的內(nèi)容為:
template <typename Dtype> \
Solver<Dtype>* Creator_SGDSolver( \
const SolverParameter& param) \
{ \
return new SGDSolver<Dtype>(param); \
} \
static SolverRegisterer<float> g_creator_f_SGD("SGD", Creator_SGDSolver<float>); \
static SolverRegisterer<double> g_creator_d_SGD("SGD", Creator_SGDSolver<double>)
從上可以看出创夜,registry[type](param)
中實(shí)際上調(diào)用了SGDSolver
帶構(gòu)造方法杭跪,事實(shí)上,網(wǎng)絡(luò)是在SGDSolver
的構(gòu)造方法中初始化的。
SGDSolver
的定義如下:
template <typename Dtype>
class SGDSolver : public Solver<Dtype> {
public:
explicit SGDSolver(const SolverParameter& param)
: Solver<Dtype>(param) { PreSolve(); }
explicit SGDSolver(const string& param_file)
: Solver<Dtype>(param_file) { PreSolve(); }
......
SGDSolver
繼承與Solver<Dtype>
涧尿,因而new SGDSolver<Dtype>(param)
將執(zhí)行Solver<Dtype>
的構(gòu)造函數(shù)系奉,然后調(diào)用自身構(gòu)造函數(shù)。整個(gè)網(wǎng)絡(luò)帶初始化即在這里面完成(詳見本系列博文(三))现斋。
3 Solver::Solve()函數(shù)
在這個(gè)函數(shù)里面喜最,程序執(zhí)行完網(wǎng)絡(luò)的完整訓(xùn)練過程。
核心代碼如下:
template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
Step(param_.max_iter() - iter_);
//..
Snapshot();
//..
// some additional display
// ...
}
說明:
- 值得關(guān)注的代碼是
Step()
庄蹋,在該函數(shù)中瞬内,值得了param_.max_iter()
輪迭代(10000) - 在Snapshot()中序列化model到文件
4 Solver::Step()函數(shù)
template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
//10000輪迭代
while (iter_ < stop_iter) {
// 每隔500輪進(jìn)行一次測試
if (param_.test_interval() && iter_ % param_.test_interval() == 0
&& (iter_ > 0 || param_.test_initialization())
&& Caffe::root_solver()) {
// 測試網(wǎng)絡(luò),實(shí)際是執(zhí)行前向傳播計(jì)算loss
TestAll();
}
// accumulate the loss and gradient
Dtype loss = 0;
for (int i = 0; i < param_.iter_size(); ++i) {
// 執(zhí)行反向傳播限书,前向計(jì)算損失loss虫蝶,并計(jì)算loss關(guān)于權(quán)值的偏導(dǎo)
loss += net_->ForwardBackward(bottom_vec);
}
// 平滑loss,計(jì)算結(jié)果用于輸出調(diào)試等
loss /= param_.iter_size();
// average the loss across iterations for smoothed reporting
UpdateSmoothedLoss(loss, start_iter, average_loss);
// 通過反向傳播計(jì)算的偏導(dǎo)更新權(quán)值
ApplyUpdate();
}
}
4.1 Solver::TestAll()函數(shù)
在TestAll()
中倦西,調(diào)用Test(test_net_id)
對每個(gè)測試網(wǎng)絡(luò)test_net(不是訓(xùn)練網(wǎng)絡(luò)train_net)進(jìn)行測試能真。在Lenet中,只有一個(gè)測試網(wǎng)絡(luò)扰柠,所以只調(diào)用一次Test(0)
進(jìn)行測試粉铐。
Test()函數(shù)里面做了兩件事:
- 前向計(jì)算網(wǎng)絡(luò),得到網(wǎng)絡(luò)損失卤档,見 (Caffe蝙泼,LeNet)前向計(jì)算(五)
- 通過測試網(wǎng)絡(luò)的第11層accuracy層,與第12層loss層結(jié)果統(tǒng)計(jì)accuracy與loss信息劝枣。
4.2 Net::ForwardBackward()函數(shù)
Dtype ForwardBackward(const vector<Blob<Dtype>* > & bottom) {
Dtype loss;
Forward(bottom, &loss);
Backward();
return loss;
}
說明:
- 前向計(jì)算汤踏。計(jì)算網(wǎng)絡(luò)損失loss,參考 (Caffe舔腾,LeNet)前向計(jì)算(五)
- 反向傳播溪胶。計(jì)算loss關(guān)于網(wǎng)絡(luò)權(quán)值的偏導(dǎo),參考 (Caffe稳诚,LeNet)反向傳播(六)
4.3 Solver::ApplyUpdate()函數(shù)
根據(jù)反向傳播階段計(jì)算的loss關(guān)于網(wǎng)絡(luò)權(quán)值的偏導(dǎo)哗脖,使用配置的學(xué)習(xí)策略,更新網(wǎng)絡(luò)權(quán)值從而完成本輪學(xué)習(xí)采桃。詳見 (Caffe懒熙,LeNet)權(quán)值更新(七)
5 訓(xùn)練完畢
至此,網(wǎng)絡(luò)訓(xùn)練優(yōu)化完成普办。在第3部分solve()函數(shù)中工扎,最后對訓(xùn)練網(wǎng)絡(luò)與測試網(wǎng)絡(luò)再執(zhí)行一輪額外的前行計(jì)算求得loss,以進(jìn)行測試衔蹲。