自蒸餾整體網(wǎng)絡(luò)結(jié)構(gòu):
其中乡革,bottleneck可減輕每個(gè)淺分類器之間的影響,添加teacher隱藏層L2 loss,并且使teacher與student網(wǎng)絡(luò)feature map輸出大小一致摊腋。
三個(gè)損失函數(shù):
- 交叉熵?fù)p失(從標(biāo)簽到最深分類器和淺分類器):根據(jù)數(shù)據(jù)集標(biāo)簽與分類器softmax輸出進(jìn)行計(jì)算
- KL散度:計(jì)算teacher與student 之間的softmax
- L2 loss:計(jì)算最深分類器與淺分類器feature map 之間的 L2 loss
總體損失:
C表示CNN中分類器個(gè)數(shù)
其中沸版,最深分類器的λ和α為零,即最深分類器的監(jiān)督僅來自標(biāo)簽兴蒸。
注意
- 自蒸餾存在梯度消失的問題推穷,因此較深的神經(jīng)網(wǎng)絡(luò)較難訓(xùn)練
- 自蒸餾一種提高模型性能的訓(xùn)練技術(shù),而不是一種壓縮模型的方法