Perform one iteration (epoch) of Neural Network training with mini-batch Stochastic Gradient Descent. The training targets are usually pdf-posteriors, prepared by ali-to-post.
Usage: nnet-train-frmshuff [options] <feature-rspecifier> <targets-rspecifier> <model-in> [<model-out>]
e.g.: nnet-train-frmshuff scp:feats.scp ark:posterior.ark nnet.init nnet.iter1
// main loop
while (!feature_reader.Done()) {
//填滿各個randomizer
for ( ; !feature_reader.Done(); feature_reader.Next()) {
// 一次循環(huán)讀一句話
// 特征放在feature_randomizer里
// targets放在targets_randomizer里
// 每一幀溪猿、每一句相關(guān)的weights放在weights_randomizer里
// 如果feature_randomizer被填滿的話钩杰,退出該for循環(huán),進行一次訓(xùn)練
// feature_randomizer的大小由相關(guān)NnetDataRandomizerOptions類的成員變量randomizer_size(默認(rèn)初始化為32768)決定
// 也就是feature_randomizer中一共可以存放32768幀诊县,存滿后就進行訓(xùn)練
// 可以通過參數(shù) --randomizer-size指定其大小
}
// 對feature_randomizer里的幀進行隨機重排
// 對target_randomizer和weights_randomizer也進行隨機重排
// 對randomizer里的數(shù)據(jù)進行訓(xùn)練(使用mini-batches)
// 幾個randomizer的Next()將指向?qū)嶋H數(shù)據(jù)開始位置的指針移動一個minibatch的大小
for ( ; !feature_randomizer.Done(); feature_randomizer.Next(),
targets_randomizer.Next(),
weights_randomizer.Next()){
// 拿出一個minibatch大小的feature/target對
// 跑網(wǎng)絡(luò)的前向
nnet.Propagate(nnet_in, &nnet_out);
// 根據(jù)目標(biāo)函數(shù)的類型讲弄,估計前向輸出和實際target的diff
// 支持的目標(biāo)函數(shù)類型:交叉熵xent,mse和multitask
// 以xent為例:
xent.Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff);
// 跑網(wǎng)絡(luò)的反向(如果不是cv(交叉驗證))
nnet.Backpropagate(obj_diff, NULL);
// 如果是第一個minibatch依痊,打印網(wǎng)絡(luò)的相關(guān)信息
}
}
// 如果是最后一個minibatch避除,打印網(wǎng)絡(luò)的相關(guān)信息
// 將nnet寫到文件(如果不是cv)
// 打印和目標(biāo)函數(shù)相關(guān)的一些信息
所有的Randmoizer都根據(jù)mask進行隨機化
/**
* Generates randomly ordered vector of indices,
*/
class RandomizerMask
/**
* Shuffles rows of a matrix according to the indices in the mask,
*/
class MatrixRandomizer
下一步學(xué)習(xí)重點:
Nnet類的幾個成員函數(shù)(nnet/nnet-nnet.h)
// 跑網(wǎng)絡(luò)的前向
nnet.Propagate(nnet_in, &nnet_out);
// 根據(jù)目標(biāo)函數(shù)的類型,估計前向輸出和實際target的diff
// 支持的目標(biāo)函數(shù)類型:交叉熵xent胸嘁,mse和multitask
// 以xent為例:
xent.Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff);
// 跑網(wǎng)絡(luò)的反向(如果不是做cv(交叉驗證))
nnet.Backpropagate(obj_diff, NULL);