This post was migrated from CSDN:
http://blog.csdn.net/mounty_fsc/article/details/51088173
1 Solver
1.1 Introduction
The Solver orchestrates the optimization of the network. Its responsibilities are:
- Scaffolds the optimization bookkeeping, and creates the training network for learning and the test network(s) for evaluation
- Iteratively optimizes by calling forward / backward and updating the weights
- Periodically evaluates the test networks
- Snapshots the model and solver state throughout the optimization
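The responsibilities listed above can be pictured as one loop that interleaves training steps with periodic evaluation and snapshots. The following is only a minimal sketch of that scheduling: `ToySchedule`, `RunToySolver`, and the interval fields are hypothetical names chosen for illustration, not Caffe's API, and the real Solver does considerably more at each stage.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for the solver settings; not Caffe's SolverParameter.
struct ToySchedule {
  int max_iter;           // total number of training iterations
  int test_interval;      // evaluate every test_interval iterations
  int snapshot_interval;  // save state every snapshot_interval iterations
};

// Records, in order, what a solver-style loop would do at each iteration.
std::vector<std::string> RunToySolver(const ToySchedule& s) {
  std::vector<std::string> log;
  for (int iter = 0; iter < s.max_iter; ++iter) {
    // Periodic evaluation of the test network(s).
    if (s.test_interval > 0 && iter % s.test_interval == 0) {
      log.push_back("test@" + std::to_string(iter));
    }
    // One training step: forward / backward followed by a weight update.
    log.push_back("step@" + std::to_string(iter));
    // Periodic snapshot of model and solver state.
    if (s.snapshot_interval > 0 && (iter + 1) % s.snapshot_interval == 0) {
      log.push_back("snapshot@" + std::to_string(iter + 1));
    }
  }
  return log;
}
```

Calling `RunToySolver(ToySchedule{4, 2, 4})` yields a log that starts with a test pass, runs four steps with a second test pass before the third step, and ends with a snapshot.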
1.2 Source code
/**
 * @brief An interface for classes that perform optimization on Net%s.
 *
 * Requires implementation of ApplyUpdate to compute a parameter update
 * given the current state of the Net parameters.
 */
template <typename Dtype>
class Solver {
 public:
  explicit Solver(const SolverParameter& param,
      const Solver* root_solver = NULL);
  explicit Solver(const string& param_file, const Solver* root_solver = NULL);
  void Init(const SolverParameter& param);
  void InitTrainNet();
  void InitTestNets();
  ...
  // The main entry of the solver function. In default, iter will be zero. Pass
  // in a non-zero iter number to resume training for a pre-trained net.
  virtual void Solve(const char* resume_file = NULL);
  inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
  void Step(int iters);
  ...
 protected:
  // Make and apply the update value for the current iteration.
  virtual void ApplyUpdate() = 0;
  ...
  SolverParameter param_;
  int iter_;
  int current_step_;
  shared_ptr<Net<Dtype> > net_;
  vector<shared_ptr<Net<Dtype> > > test_nets_;
  vector<Callback*> callbacks_;
  vector<Dtype> losses_;
  Dtype smoothed_loss_;
  // The root solver that holds root nets (actually containing shared layers)
  // in data parallelism
  const Solver* const root_solver_;
  ...
};
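Notes:
- shared_ptr<Net<Dtype> > net_ is the pointer to the training network, and vector<shared_ptr<Net<Dtype> > > test_nets_ is the set of pointers to the test networks, so there can be more than one test network.
- In general the training network and the test networks differ in implementation, but the vast majority of their layers are the same.
- Each training method implements the core computation of the parameter update by overriding a virtual function. The text here names it ComputeUpdateValue(), whereas the pure virtual function in the declaration above is ApplyUpdate(), so the description appears to follow an older version of the class.
- The train() function in caffe.cpp trains the model: it instantiates a Solver object, initializes it, and then calls the Solver's Solve() method. As described here, Solve() mainly iterates over the following two calls (in the declaration above, Step() and ApplyUpdate() appear to play this role):
ComputeUpdateValue();
net_->Update();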
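The division of labour described in these notes, where the base class owns the iteration loop and each method supplies only its update rule, can be sketched with a toy class hierarchy. Everything below is an illustrative assumption: `ToySolver`, `ToySGDSolver`, the scalar weight, and the quadratic loss are made up for the example and only mirror the shape of the `Step()` / `ApplyUpdate()` pair in the declaration above.

```cpp
#include <cassert>
#include <cmath>

// Toy analogue of the Solver interface: the base class owns the loop,
// subclasses supply the update rule. All names here are hypothetical.
class ToySolver {
 public:
  explicit ToySolver(double w0) : w_(w0), grad_(0.0), iter_(0) {}
  virtual ~ToySolver() {}

  // Analogue of Step(): compute the gradient, then delegate the update.
  void Step(int iters) {
    for (int i = 0; i < iters; ++i) {
      grad_ = 2.0 * (w_ - 3.0);  // gradient of the toy loss (w - 3)^2
      ApplyUpdate();
      ++iter_;
    }
  }
  double weight() const { return w_; }
  int iter() const { return iter_; }

 protected:
  // Supplied by each optimization method.
  virtual void ApplyUpdate() = 0;
  double w_;     // the single "network parameter"
  double grad_;  // gradient computed by the latest backward pass
  int iter_;     // number of completed iterations

};

// Plain gradient descent with a fixed learning rate.
class ToySGDSolver : public ToySolver {
 public:
  ToySGDSolver(double w0, double lr) : ToySolver(w0), lr_(lr) {}

 protected:
  virtual void ApplyUpdate() { w_ -= lr_ * grad_; }

 private:
  double lr_;
};
```

With a starting weight of 0 and a learning rate of 0.1, a call such as `ToySGDSolver s(0.0, 0.1); s.Step(100);` drives the weight toward the minimum of the toy loss at 3. Swapping in a different optimization method only requires another subclass with its own `ApplyUpdate()`.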
1.3 Solver methods
- Stochastic Gradient Descent (type: "SGD")
- AdaDelta (type: "AdaDelta")
- Adaptive Gradient (type: "AdaGrad")
- Adam (type: "Adam")
- Nesterov’s Accelerated Gradient (type: "Nesterov")
- RMSprop (type: "RMSProp")
See reference 1 for details.
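To give a feel for how the methods in this list differ, here is a rough sketch of two textbook update rules applied to a single scalar weight: SGD with momentum and AdaGrad. The function names, the scalar setting, and the `eps` stabilizer are assumptions made for illustration; this is not Caffe's implementation, which operates on whole parameter blobs and adds learning-rate policies and regularization.

```cpp
#include <cassert>
#include <cmath>

// SGD with momentum on one scalar weight:
//   v <- mu * v - lr * g;   w <- w + v
// v is the velocity carried between iterations.
void SgdMomentumUpdate(double g, double lr, double mu, double* v, double* w) {
  *v = mu * (*v) - lr * g;
  *w += *v;
}

// AdaGrad on one scalar weight:
//   h <- h + g^2;   w <- w - lr * g / (sqrt(h) + eps)
// h accumulates squared gradients; eps (an assumed small constant) guards
// against division by zero.
void AdaGradUpdate(double g, double lr, double eps, double* h, double* w) {
  *h += g * g;
  *w -= lr * g / (std::sqrt(*h) + eps);
}
```

The momentum rule reuses part of the previous step, while AdaGrad shrinks the effective learning rate of a weight as its accumulated squared gradient grows.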
2 The Caffe class
Caffe is a singleton class that holds commonly used Caffe members, such as the handles of the CUDA libraries (cublas, curand) that Caffe uses, as well as the generator for Caffe's random numbers.
// common.hpp
// A singleton class to hold common caffe stuff, such as the handler that
// caffe is going to use for cublas, curand, etc.
class Caffe {
 public:
  ~Caffe();
  // Thread local context for Caffe. Moved to common.cpp instead of
  // including boost/thread.hpp to avoid a boost/NVCC issues (#1009, #1010)
  // on OSX. Also fails on Linux with CUDA 7.0.18.
  static Caffe& Get();
  enum Brew { CPU, GPU };
  ...
 protected:
#ifndef CPU_ONLY
  cublasHandle_t cublas_handle_;
  curandGenerator_t curand_generator_;
#endif
  shared_ptr<RNG> random_generator_;
  Brew mode_;
  int solver_count_;
  bool root_solver_;
 private:
  // The private constructor to avoid duplicate instantiation.
  Caffe();
  DISABLE_COPY_AND_ASSIGN(Caffe);
};
// common.cpp
namespace caffe {
// Make sure each thread can have different values.
static boost::thread_specific_ptr<Caffe> thread_instance_;
Caffe& Caffe::Get() {
  if (!thread_instance_.get()) {
    thread_instance_.reset(new Caffe());
  }
  return *(thread_instance_.get());
}
...
Caffe::Caffe()
    : cublas_handle_(NULL), curand_generator_(NULL), random_generator_(),
      mode_(Caffe::CPU), solver_count_(1), root_solver_(true) {
  // Try to create a cublas handler, and report an error if failed (but we will
  // keep the program running as one might just want to run CPU code).
  if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) {
    LOG(ERROR) << "Cannot create Cublas handle. Cublas won't be available.";
  }
  // Try to create a curand handler.
  if (curandCreateGenerator(&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT)
      != CURAND_STATUS_SUCCESS ||
      curandSetPseudoRandomGeneratorSeed(curand_generator_, cluster_seedgen())
      != CURAND_STATUS_SUCCESS) {
    LOG(ERROR) << "Cannot create Curand generator. Curand won't be available.";
  }
}
...
}  // namespace caffe
Notes:
- Caffe is a singleton class with a private constructor.
- The singleton is maintained by static boost::thread_specific_ptr<Caffe> thread_instance_, which ensures that in a multithreaded environment each thread has its own Caffe instance.
- The singleton is obtained through Get(): Caffe::Get() returns the instance held by thread_instance_.
- thread_instance_ is initially NULL; on the first access, new Caffe() is executed.
- new Caffe() runs the constructor, which in fact only creates the cublas and curand handles.
- Single-step debugging shows that cublasCreate(), which creates the cublas handle, spawns two additional threads.
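The per-thread singleton behaviour described in these notes can be tried out in isolation. The sketch below deliberately swaps in the C++11 `thread_local` keyword for `boost::thread_specific_ptr`, and `ToyContext` with its helper functions is a hypothetical stand-in rather than the Caffe class, so treat it as an illustration of the idea only.

```cpp
#include <cassert>
#include <thread>

// Toy per-thread singleton. C++11 thread_local is used here in place of
// boost::thread_specific_ptr; ToyContext is a hypothetical name.
class ToyContext {
 public:
  static ToyContext& Get() {
    // Constructed lazily, once per thread, on that thread's first call.
    thread_local ToyContext instance;
    return instance;
  }
  int solver_count;

 private:
  // Private constructor: only Get() can create instances.
  ToyContext() : solver_count(1) {}
  ToyContext(const ToyContext&) = delete;
  ToyContext& operator=(const ToyContext&) = delete;
};

// Returns solver_count as seen by a freshly started thread.
int SolverCountInOtherThread() {
  int count = 0;
  std::thread t([&count] { count = ToyContext::Get().solver_count; });
  t.join();
  return count;
}

// Returns true if a new thread gets an instance distinct from the caller's.
bool OtherThreadSeesDifferentInstance() {
  const ToyContext* mine = &ToyContext::Get();
  bool different = false;
  std::thread t([mine, &different] {
    different = (&ToyContext::Get() != mine);
  });
  t.join();
  return different;
}
```

Within one thread, repeated calls to `ToyContext::Get()` return the same object; a change made to `solver_count` in one thread is not visible from another, which is the property the `thread_instance_` pointer provides for the Caffe class.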
3 Batch
template <typename Dtype>
class Batch {
 public:
  Blob<Dtype> data_, label_;
};
Notes:
- Batch wraps a batch of samples. It differs from Datum: a Datum is database-facing, and one Datum corresponds to one sample (image, label), whereas a Batch is network-facing, and one Batch corresponds to a batch of samples.
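The relationship between the two can be sketched as packing several per-sample records into one pair of contiguous buffers. `ToyDatum`, `ToyBatch`, and `PackBatch` are hypothetical simplifications introduced for this example: Caffe's Datum is a protobuf message and Batch holds `Blob<Dtype>` members, neither of which is reproduced here, and the samples are assumed to be equally sized.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical simplified record for one sample (database-facing).
struct ToyDatum {
  std::vector<float> pixels;  // one sample's data
  int label;                  // one sample's label
};

// Hypothetical simplified container for N samples (network-facing).
struct ToyBatch {
  std::vector<float> data;   // N samples stored back to back
  std::vector<float> label;  // N labels, one per sample
};

// Copies the samples, assumed to be equally sized, into one contiguous batch.
ToyBatch PackBatch(const std::vector<ToyDatum>& samples) {
  ToyBatch batch;
  for (std::size_t i = 0; i < samples.size(); ++i) {
    batch.data.insert(batch.data.end(), samples[i].pixels.begin(),
                      samples[i].pixels.end());
    batch.label.push_back(static_cast<float>(samples[i].label));
  }
  return batch;
}
```

Packing two samples of two values each gives a batch with four data values and two labels, with the second sample's data starting at offset 2.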