解碼就是輸入音頻,利用聲學(xué)模型案腺、構(gòu)建好的WFST解碼網(wǎng)絡(luò)啰脚,輸出最優(yōu)狀態(tài)序列的過程殷蛇。以Kaldi中LatticeFasterOnlineDecoder為例,解析解碼代碼橄浓。
示例程序:
online2-wav-nnet3-latgen-faster --do-endpointing=false --online=false --frame-subsampling-factor=3
--config=conf/online.conf --max-active=7000 --beam=15.0 --frames-per-chunk=50 --lattice-beam=6.0
--acoustic-scale=1.0 --word-symbol-table=words.txt final.mdl HCLG.fst ark:spk2utt.txt scp:test.scp ark,t:lat.debug.txt
聲學(xué)模型:final.mdl Kaldi Chain model 文件解析
WFST:HCLG.fst
spk2utt.txt 內(nèi)容如下:
wav10 wav10
wav9 wav9
test.scp 內(nèi)容如下:
wav10 data/wav/00030/2017_03_07_16.57.22_1175.wav
wav9 data/wav/00030/2017_03_07_16.57.40_2562.wav
主要數(shù)據(jù)結(jié)構(gòu):
- Token
struct Token {
BaseFloat tot_cost; // 到該狀態(tài)的累計最優(yōu)cost
BaseFloat extra_cost; //token所有ForwardLinks中和最優(yōu)路徑的cost差的最小值粒梦,PruneActiveTokens 用到
ForwardLink *links; // 鏈表,表示現(xiàn)在時刻到下一時刻的那條跳轉(zhuǎn)邊
Token *next; // 指向同一時刻的下一個token
Token *backpointer; // 指向上一時刻的最佳token荸实,相當(dāng)于一個回溯指針
};
- ForwardLink
struct ForwardLink {
Token *next_tok; // 這條鏈接指向的token
Label ilabel; // 這下面的四個量取自解碼圖中的跳轉(zhuǎn)/弧/邊匀们,因為每一個狀態(tài)
Label olabel; // 維護(hù)一個token,那么token到token之間的連接信息和狀態(tài)到狀態(tài)之間的信息
BaseFloat graph_cost; // 應(yīng)該保持一致准给,所以會有輸入(tid)泄朴,輸出,權(quán)值(就是graph_cost)
BaseFloat acoustic_cost; // acoustic_cost就是tid對應(yīng)的pdf_id的在聲學(xué)模型中的后驗
ForwardLink *next; // 鏈表結(jié)構(gòu)露氮,指向下一個
};
- TokenList
struct TokenList {
Token *toks; // 同一時刻的token鏈表頭
bool must_prune_forward_links; // 這兩個是Lattice剪枝標(biāo)記祖灰,起始默認(rèn)設(shè)置為true
bool must_prune_tokens;
};
- HashList
template<class I, class T> class HashList {
struct Elem {
I key; // state
T val; // Token
Elem *tail;
};
struct HashBucket {
size_t prev_bucket; // 指向下一個桶,最后一個指向-1
Elem *last_elem; // 指向掛在桶上的最后一個元素畔规,空桶指向NULL
};
Elem *list_head_; // 鏈表頭
size_t bucket_list_tail_; // 當(dāng)前活躍桶最后一個下標(biāo)
size_t hash_size_; // 當(dāng)前活躍桶個數(shù)
std::vector<HashBucket> buckets_; //存儲實際活躍的桶
Elem *freed_head_; // head of list of currently freed elements. [ready for allocation]
std::vector<Elem*> allocated_; // list of allocated blocks.
};
解碼過程中上述數(shù)據(jù)結(jié)構(gòu)對應(yīng)的一些重要變量如下(來自decoder/lattice-faster-online-decoder.h)
HashList<StateId, Token*> toks_;
std::vector<TokenList> active_toks_; // 每一幀對應(yīng)其中一個TokenList局扶,等于frame+1,
std::vector<StateId> queue_; // 臨時變量油讯,用于ProcessNonemitting详民,保存的是下一時刻state
std::vector<BaseFloat> tmp_array_; // used in GetCutoff.
解碼整體流程:
- 模型、文件加載陌兑,配置生成沈跨;
- 三層循環(huán)
for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) { //循環(huán)speaker
...
const std::vector<std::string> &uttlist = spk2utt_reader.Value();
for (size_t i = 0; i < uttlist.size(); i++) { //循環(huán)某個speaker的所有wav
SingleUtteranceNnet3Decoder decoder(decoder_opts, trans_model, decodable_info, *decode_fst, &feature_pipeline); //構(gòu)造函數(shù)中調(diào)用InitDecoding()
//循環(huán)某個wav的chunk,比如說一幀一幀兔综,online=false的時候一次加載整個wav
while (samp_offset < data.Dim()) {
decoder.AdvanceDecoding();
}
decoder.FinalizeDecoding();
decoder.GetLattice(end_of_utterance, &clat);
GetDiagnosticsAndPrintOutput(utt, word_syms, clat,&num_frames, &tot_like);
}
}
對于單個wav饿凛,最主要流程就是三個函數(shù):
void InitDecoding();
void LatticeFasterOnlineDecoder::AdvanceDecoding(DecodableInterface *decodable, int32 max_num_frames=-1)
void FinalizeDecoding();
其中AdvanceDecoding主流程如下圖狞玛,每幀數(shù)據(jù)處理流程包括:
BaseFloat ProcessEmittingWrapper(DecodableInterface *decodable);
實際調(diào)用LatticeFasterOnlineDecoder::ProcessEmitting<fst::VectorFst<Arc>>(decodable);
處理輸入非空跳轉(zhuǎn)(ilabel != 0),主體兩層循環(huán)涧窒,外層循環(huán)現(xiàn)在時刻所有Token心肪,內(nèi)層循環(huán)每個現(xiàn)在時刻的state能夠跳轉(zhuǎn)的下一時刻所有state。
ProcessEmitting 函數(shù)中vector active_toks_ 加1(active_toks_.resize(active_toks_.size() + 1);)纠吴,另外硬鞍,NumFramesDecoded() 返回值等于active_toks_.size() - 1。void ProcessNonemittingWrapper(BaseFloat cost_cutoff);
實際調(diào)用LatticeFasterOnlineDecoder::ProcessNonemitting<fst::VectorFst<Arc>>(cost_cutoff);
處理輸入空跳轉(zhuǎn)(ilabel == 0)戴已,主體兩層循環(huán)固该,外層循環(huán)下一時刻所有Token,內(nèi)層循環(huán)每個下一時刻的state能夠跳轉(zhuǎn)到的的state糖儡》セ担可以這樣理解,下一時刻的空跳轉(zhuǎn)還是現(xiàn)在時刻通過一幀能夠到達(dá)的時刻握联。void PruneActiveTokens(BaseFloat delta);
lattice beam 剪枝桦沉,默認(rèn)25幀一次,包括兩部分:剪枝ForwardLinks(PruneForwardLinks函數(shù))金闽,剪枝Tokens(PruneTokensForFrame函數(shù))
- 打印統(tǒng)計信息
主要函數(shù)解析:
- ProcessEmitting (decoder/lattice-faster-online-decoder.cc)
template <typename FstType>
BaseFloat LatticeFasterOnlineDecoder::ProcessEmitting(
DecodableInterface *decodable) {
KALDI_ASSERT(active_toks_.size() > 0);
int32 frame = active_toks_.size() - 1;
active_toks_.resize(active_toks_.size() + 1); //每幀+1纯露,外層調(diào)用的while循環(huán)也是
Elem *final_toks = toks_.Clear(); // 此處clear的是bucket,返回鏈表頭呐矾,遍歷可得現(xiàn)在時刻所有state的鏈表
Elem *best_elem = NULL;
BaseFloat adaptive_beam;
size_t tok_cnt;
// Beam prune 參數(shù)獲取苔埋,包括cur_cutoff懦砂,adaptive_beam, best_elem蜒犯。 后兩者用來確定next_cutoff
// 主要是兩個條件,默認(rèn)是best_weight + config_.beam荞膘,同時用config_.max_active罚随、config_.min_active 做了加強,希望state數(shù)目在[config_.min_active, config_.max_active]之間
BaseFloat cur_cutoff = GetCutoff(final_toks, &tok_cnt, &adaptive_beam, &best_elem);
PossiblyResizeHash(tok_cnt); // This makes sure the hash is always big enough.
BaseFloat next_cutoff = std::numeric_limits<BaseFloat>::infinity();
BaseFloat cost_offset = 0.0; // Used to keep probabilities in a good
const FstType &fst = dynamic_cast<const FstType&>(fst_);
// 下面這個塊只是為了得到next_cutoff and cost_offset.
// next_cutoff 用于下一時刻state的beam prune羽资。等于現(xiàn)在時刻最優(yōu)state到下一時刻對應(yīng)所有state中最優(yōu)的tot_cost
// cost_offset 只是為了計算方面的考慮淘菩,相當(dāng)于同時減了一個最小數(shù)。
if (best_elem) {
StateId state = best_elem->key;
Token *tok = best_elem->val;
cost_offset = - tok->tot_cost;
for (fst::ArcIterator<FstType> aiter(fst, state);
!aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
if (arc.ilabel != 0) { // propagate..
BaseFloat new_weight = arc.weight.Value() + cost_offset -
decodable->LogLikelihood(frame, arc.ilabel) + tok->tot_cost; // 這一步cost_offset + tok_tot_cost === 0屠升,可以不要
if (new_weight + adaptive_beam < next_cutoff)
next_cutoff = new_weight + adaptive_beam;
}
}
}
...
// the tokens are now owned here, in final_toks, and the hash is empty.
// 'owned' is a complex thing here; the point is we need to call DeleteElem
// on each elem 'e' to let toks_ know we're done with them.
for (Elem *e = final_toks, *e_tail; e != NULL; e = e_tail) { //外層循環(huán)潮改,遍歷現(xiàn)在時刻state
// loop this way because we delete "e" as we go.
StateId state = e->key;
Token *tok = e->val;
if (tok->tot_cost <= cur_cutoff) { // 現(xiàn)在時刻beam prune,tot_cost控制在cur_cutoff閾值以內(nèi)腹暖,cur_cutoff=現(xiàn)在時刻最優(yōu)state tot_cost+beam
for (fst::ArcIterator<FstType> aiter(fst, state); // 內(nèi)層循環(huán)汇在,遍歷現(xiàn)在時刻某個state的所有跳轉(zhuǎn)
!aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
if (arc.ilabel != 0) { // 輸入非空跳轉(zhuǎn)
BaseFloat ac_cost = cost_offset -
decodable->LogLikelihood(frame, arc.ilabel),
graph_cost = arc.weight.Value(),
cur_cost = tok->tot_cost,
tot_cost = cur_cost + ac_cost + graph_cost;
if (tot_cost > next_cutoff) continue;
// 下一時刻beam prune,下一時刻tot_cost控制在閾值next_cutoff之內(nèi)脏答。
// next_cutoff糕殉,初始值為:現(xiàn)在時刻最優(yōu)state到下一時刻所有state中最優(yōu)cost+adaptive_beam亩鬼。注意不是下一時刻所有state中最優(yōu)cost+adaptive_beam,后面再動態(tài)調(diào)整阿蝶。
else if (tot_cost + adaptive_beam < next_cutoff)
next_cutoff = tot_cost + adaptive_beam;
//擴展下一時刻token,存取在toks_中雳锋,這一幀的ProcessNonemitting就是在toks_對應(yīng)的list中循環(huán)。所以說ProcessNonemitting循環(huán)的是下一時刻的state以及下一時刻state的擴展跳轉(zhuǎn)羡洁。
Token *next_tok = FindOrAddToken(arc.nextstate,frame + 1, tot_cost, tok, NULL);
// 加邊玷过。Add ForwardLink from tok to next_tok (put on head of list tok->links)
tok->links = new ForwardLink(next_tok, arc.ilabel, arc.olabel,
graph_cost, ac_cost, tok->links);
}
} // for all arcs
}
e_tail = e->tail;
toks_.Delete(e); // delete Elem
}
return next_cutoff;
}
主體流程是雙層循環(huán),也就是Viterbi解碼筑煮,外層循環(huán)現(xiàn)在時刻所有state冶匹,內(nèi)層循環(huán)每個state對應(yīng)的每個跳轉(zhuǎn),確定下一時刻所有state咆瘟。過程中生成state對應(yīng)的Token以及ForwardLink嚼隘。同時用到了Beam Prune,現(xiàn)在時刻和下一時刻都有應(yīng)用袒餐。
ProcessNonemitting(BaseFloat cutoff) (decoder/lattice-faster-online-decoder.cc)
首先遍歷前面ProcessEmitting函數(shù)生成的HashList飞蛹,得到現(xiàn)在時刻state 隊列 queue_
然后兩層遍歷:外層遍歷queue_,內(nèi)層遍歷stata的空跳轉(zhuǎn)灸眼;
注意一點的是:frame = static_cast<int32>(active_toks_.size()) - 2 卧檐,這個如果不注意,理解內(nèi)循環(huán)中的FindOrAddToken函數(shù)會出現(xiàn)偏差焰宣。FindOrAddToken
構(gòu)造Token霉囚,插入到active_toks_[frame_plus_one].toks指向的Token list中,插入到HashList toks_中
inline LatticeFasterOnlineDecoder::Token *LatticeFasterOnlineDecoder::FindOrAddToken(
StateId state, int32 frame_plus_one, BaseFloat tot_cost,
Token *backpointer, bool *changed) {
// Returns the Token pointer. Sets "changed" (if non-NULL) to true
// if the token was newly created or the cost changed.
KALDI_ASSERT(frame_plus_one < active_toks_.size());
Token *&toks = active_toks_[frame_plus_one].toks; // 引用匕积,注意后面的改變其實改變了右邊的值
Elem *e_found = toks_.Find(state); //HashList中查找
if (e_found == NULL) { // no such token presently.
const BaseFloat extra_cost = 0.0;
Token *new_tok = new Token (tot_cost, extra_cost, NULL, toks, backpointer); //構(gòu)造Token盈罐,頭插
toks = new_tok;
num_toks_++;
toks_.Insert(state, new_tok); //toks_是一個HashList,ProcessNonemitting函數(shù)或者下一幀會用到
if (changed) *changed = true;
return new_tok;
} else {
Token *tok = e_found->val; // There is an existing Token for this state.
if (tok->tot_cost > tot_cost) { // replace old token
tok->tot_cost = tot_cost;
tok->backpointer = backpointer;
if (changed) *changed = true;
} else {
if (changed) *changed = false;
}
return tok;
}
}
- GetCutoff(Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam, Elem **best_elem)
Viterbi解碼中涉及到現(xiàn)在時刻state數(shù)目以及下一時刻state數(shù)目闪唆,如果我們想要提高解碼速度盅粪,需要對這兩個數(shù)值都做縮減。實際做法是設(shè)置閾值悄蕾,減少語音識別中現(xiàn)在時刻以及下一時刻狀態(tài)數(shù)目票顾,具體做法是:** 首先求現(xiàn)在時刻最優(yōu)路徑得分,加上beam帆调,得到現(xiàn)在時刻得分閾值奠骄;然后求下一時刻最優(yōu)路徑得分,加上beam番刊,得到下一時刻得分閾值**含鳞;具體步驟是:
- 對所有狀態(tài)排序,最優(yōu)狀態(tài)放最前面撵枢,最優(yōu)狀態(tài)得分=best_weight
- 設(shè)置一個beam民晒,設(shè)置閾值1=cur_cutoff精居,cur_cutoff=best_weight+beam,所有得分在cur_cutoff以內(nèi)的潜必,保留靴姿,反之丟棄,現(xiàn)在時刻的state數(shù)目減少磁滚。
- 計算到下一時刻的最優(yōu)路徑得分new_weight佛吓。
- 設(shè)置一個adaptive_beam, 設(shè)置閾值2=next_cutoff,next_cutoff=new_weight+adaptive_beam垂攘,所有得分在next_cutoff以內(nèi)的维雇,保留,反之丟棄晒他,下一時刻的state數(shù)目減少吱型。
注意上述步驟中的beam不是參數(shù)傳遞進(jìn)去的config_.beam;因為我們?nèi)绻苯佑胏onfig_.beam陨仅,有可能卡出的state數(shù)目太多(大于config_.max_active)或者太少(少于config_.min_active)津滞。所以需要分類討論,確定最終的beam值灼伤,adaptive_beam類似触徐。
cur_cutoff,adaptive_beam 都是來自GetCutoff函數(shù):
// BaseFloat cur_cutoff = GetCutoff(final_toks, &tok_cnt, &adaptive_beam, &best_elem);
// 輸入final_toks狐赡,HashList對應(yīng)的list撞鹉,toks_.Clear() 操作后的得頭結(jié)點指向
// 輸出 cur_cutoff,返回值颖侄,用于現(xiàn)在時刻Beam Prune
// 輸出 adaptive_beam, best_elem 得到next_cutoff鸟雏,用于下一時刻Beam Prune
// 輸出 tok_cnt 用于重置HashList toks_大小 ,足夠大发皿,減少內(nèi)存分配時間
PossiblyResizeHash(tok_cnt)
BaseFloat LatticeFasterOnlineDecoder::GetCutoff(Elem *list_head, size_t *tok_count,
BaseFloat *adaptive_beam, Elem **best_elem)
- PruneActiveTokens
從后向前崔慧,主要做兩步操作:
PruneForwardLinks拂蝎,刪減Token的ForwordLinks穴墅,
PruneTokensForFrame,刪減Token本身温自,如果該Token對應(yīng)的所有的ForwardLinks 都沒有了玄货,那Token本身也可以刪除,判斷條件tok->extra_cost == std::numeric_limits<BaseFloat>::infinity()悼泌,extra_cost代表該tok所有ForwardLinks到的next state 的tot_cost和到達(dá)該next state最優(yōu)路徑的tot_cost差的最小值松捉,如果是無窮大(最小值都是無窮大)代表所有ForwordLinks都刪除了。
Reference
http://www.funcwj.cn/2017/08/02/kaldi-online-decoder/
https://blog.csdn.net/u013677156/article/details/78930532