【知識(shí)蒸餾】Knowledge Review

【GiantPandaCV引言】 知識(shí)回顧(KR)發(fā)現(xiàn)學(xué)生網(wǎng)絡(luò)深層可以通過利用教師網(wǎng)絡(luò)淺層特征進(jìn)行學(xué)習(xí)岂津,基于此提出了回顧機(jī)制卧晓,包括ABF和HCL兩個(gè)模塊芬首,可以在很多分類任務(wù)上得到一致性的提升。

摘要

知識(shí)蒸餾通過將知識(shí)從教師網(wǎng)絡(luò)傳遞到學(xué)生網(wǎng)絡(luò)逼裆,但是之前的方法主要關(guān)注提出特征變換和實(shí)施相同層的特征郁稍。

知識(shí)回顧Knowledge Review選擇研究教師與學(xué)生網(wǎng)絡(luò)之間不同層之間的路徑鏈接。

簡單來說就是研究教師網(wǎng)絡(luò)向?qū)W生網(wǎng)絡(luò)傳遞知識(shí)的鏈接方式胜宇。

代碼在:https://github.com/Jia-Research-Lab/ReviewKD

KD簡單回顧

KD最初的蒸餾對(duì)象是logits層耀怜,也即最經(jīng)典的Hinton的那篇Knowledge Distillation,讓學(xué)生網(wǎng)絡(luò)和教師網(wǎng)絡(luò)的logits KL散度盡可能小掸屡。

隨后FitNets出現(xiàn)開始蒸餾中間層封寞,一般通過使用MSE Loss讓學(xué)生網(wǎng)絡(luò)和教師網(wǎng)絡(luò)特征圖盡可能接近。

Attention Transfer進(jìn)一步發(fā)展了FitNets仅财,提出使用注意力圖來作為引導(dǎo)知識(shí)的傳遞狈究。

PKT(Probabilistic knowledge transfer for deep representation learning)將知識(shí)作為概率分布進(jìn)行建模。

Contrastive representation Distillation(CRD)引入對(duì)比學(xué)習(xí)來進(jìn)行知識(shí)遷移盏求。

以上方法主要關(guān)注于知識(shí)遷移的形式以及選擇不同的loss function抖锥,但KR關(guān)注于如何選擇教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)的鏈接,一下圖為例:

image

(a-c)都是傳統(tǒng)的知識(shí)蒸餾方法碎罚,通常都是相同層的信息進(jìn)行引導(dǎo)磅废,(d)代表KR的蒸餾方式,可以使用教師網(wǎng)絡(luò)淺層特征來作為學(xué)生網(wǎng)絡(luò)深層特征的監(jiān)督荆烈,并發(fā)現(xiàn)學(xué)生網(wǎng)絡(luò)深層特征可以從教師網(wǎng)絡(luò)的淺層學(xué)習(xí)到知識(shí)拯勉。

教師網(wǎng)絡(luò)淺層到深層分別對(duì)應(yīng)的知識(shí)抽象程度不斷提高竟趾,學(xué)習(xí)難度也進(jìn)行了提升,所以學(xué)生網(wǎng)絡(luò)如果能在初期學(xué)習(xí)到教師網(wǎng)絡(luò)淺層的知識(shí)會(huì)對(duì)整體有幫助宫峦。

KR認(rèn)為淺層的知識(shí)可以作為舊知識(shí)岔帽,并進(jìn)行不斷回顧,溫故知新导绷。如何從教師網(wǎng)絡(luò)中提取多尺度信息是本文待解決的關(guān)鍵:

  • 提出了Attention based fusion(ABF) 進(jìn)行特征fusion

  • 提出了Hierarchical context loss(HCL) 增強(qiáng)模型的學(xué)習(xí)能力犀勒。

Knowledge Review

形式化描述

X是輸入圖像,S代表學(xué)生網(wǎng)絡(luò)妥曲,其中\left(\mathcal{S}_{1}, \mathcal{S}_{2}, \cdots, \mathcal{S}_{n}, \mathcal{S}_{c}\right)代表學(xué)生網(wǎng)絡(luò)各個(gè)層的組成贾费。

\mathbf{Y}_{s}=\mathcal{S}_{c} \circ \mathcal{S}_{n} \circ \cdots \circ \mathcal{S}_{1}(\mathbf{X})

Ys代表X經(jīng)過整個(gè)網(wǎng)絡(luò)以后的輸出。\left(\mathbf{F}_{s}^{1}, \cdots, \mathbf{F}_{s}^{n}\right)代表各個(gè)層中間層輸出檐盟。

那么單層知識(shí)蒸餾可以表示為:

\mathcal{L}_{S K D}=\mathcal{D}\left(\mathcal{M}_{s}^{i}\left(\mathbf{F}_{s}^{i}\right), \mathcal{M}_{t}^{i}\left(\mathbf{F}_{t}^{i}\right)\right)

M代表一個(gè)轉(zhuǎn)換褂萧,從而讓Fs和Ft的特征圖相匹配。D代表衡量兩者分布的距離函數(shù)遵堵。

同理多層知識(shí)蒸餾表示為:

\mathcal{L}_{M K D}=\sum_{i \in \mathbf{I}} \mathcal{D}\left(\mathcal{M}_{s}^{i}\left(\mathbf{F}_{s}^{i}\right), \mathcal{M}_{t}^{i}\left(\mathbf{F}_{t}^{i}\right)\right)

以上公式是學(xué)生和教師網(wǎng)絡(luò)層層對(duì)應(yīng)箱玷,那么單層KR表示方式為:

具體

與之前不同的是,這里計(jì)算的是從j=1 to i 代表第i層學(xué)生網(wǎng)絡(luò)的學(xué)習(xí)需要用到從第1到i層所有知識(shí)陌宿。

同理,多層的KR表示為:

\mathcal{L}_{M K D_{-} R}=\sum_{i \in \mathbf{I}}\left(\sum_{j=1}^{i} \mathcal{D}\left(\mathcal{M}_{s}^{i, j}\left(\mathbf{F}_{s}^{i}\right), \mathcal{M}_{t}^{j, i}\left(\mathbf{F}_{t}^{j}\right)\right)\right)

Fusion方式設(shè)計(jì)

已經(jīng)確定了KR的形式波丰,即學(xué)生每一層回顧教師網(wǎng)絡(luò)的所有靠前的層壳坪,那么最簡單的方法是:

image

直接縮放學(xué)生網(wǎng)絡(luò)最后一層feature,讓其形狀和教師網(wǎng)絡(luò)進(jìn)行匹配掰烟,這樣\mathcal{M}_s^{i,j}可以簡單使用一個(gè)卷積層配合插值層完成形狀的匹配過程爽蝴。這種方式是讓學(xué)生網(wǎng)絡(luò)更接近教師網(wǎng)絡(luò)。

image

這張圖表示擴(kuò)展了學(xué)生網(wǎng)絡(luò)所有層對(duì)應(yīng)的處理方式纫骑,也即按照第一張圖的處理方式進(jìn)行形狀匹配蝎亚。

這種處理方式可能并不是最優(yōu)的,因?yàn)闀?huì)導(dǎo)致stage之間出現(xiàn)巨大的差異性先馆,同時(shí)處理過程也非常復(fù)雜发框,帶來了額外的計(jì)算代價(jià)。

為了讓整個(gè)過程更加可行煤墙,提出了Attention based fusion \mathcal{U}, 這樣整體蒸餾變?yōu)椋?/p>

\sum_{i=j}^{n} \mathcal{D}\left(\mathbf{F}_{s}^{i}, \mathbf{F}_{t}^{j}\right) \approx \mathcal{D}\left(\mathcal{U}\left(\mathbf{F}_{s}^{j}, \cdots, \mathbf{F}_{s}^{n}\right), \mathbf{F}_{t}^{j}\right)

如果引入了fusion的模塊梅惯,那整體流程就變?yōu)橄聢D所示:

image

但是為了更高的效率,再對(duì)其進(jìn)行改進(jìn):

image

可以發(fā)現(xiàn)仿野,這個(gè)過程將fusion的中間結(jié)果進(jìn)行了利用铣减,即\mathbf{F}_{s}^{j} \text { and } \mathcal{U}\left(\mathbf{F}_{s}^{j+1}, \cdots, \mathbf{F}_{s}^{n}\right), 這樣循環(huán)從后往前進(jìn)行迭代,就可以得到最終的loss脚作。

具體來說葫哗,ABF的設(shè)計(jì)如下(a)所示,采用了注意力機(jī)制融合特征,具體來說中間的1x1 conv對(duì)兩個(gè)level的feature提取綜合空間注意力特征圖劣针,然后再進(jìn)行特征重標(biāo)定校镐,可以看做SKNet的空間注意力版本。

image

而HCL Hierarchical context loss 這里對(duì)分別來自于學(xué)生網(wǎng)絡(luò)和教師網(wǎng)絡(luò)的特征進(jìn)行了空間池化金字塔的處理酿秸,L2 距離用于衡量兩者之間的距離灭翔。

KR認(rèn)為這種方式可以捕獲不同level的語義信息,可以在不同的抽象等級(jí)提取信息辣苏。

實(shí)驗(yàn)

實(shí)驗(yàn)部分主要關(guān)注消融實(shí)驗(yàn):

第一個(gè)是使用不同stage的結(jié)果:

image

藍(lán)色的值代表比baseline 69.1更好肝箱,紅色代表要比baseline更差。通過上述結(jié)果可以發(fā)現(xiàn)使用教師網(wǎng)絡(luò)淺層知識(shí)來監(jiān)督學(xué)生網(wǎng)絡(luò)深層知識(shí)是有效的稀蟋。

第二個(gè)是各個(gè)模塊的作用:

image

源碼

主要關(guān)注ABF煌张, HCL的實(shí)現(xiàn):

ABF實(shí)現(xiàn):

class ABF(nn.Module):
    def __init__(self, in_channel, mid_channel, out_channel, fuse):
        super(ABF, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, mid_channel, kernel_size=1, bias=False),
            nn.BatchNorm2d(mid_channel),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(mid_channel, out_channel,kernel_size=3,stride=1,padding=1,bias=False),
            nn.BatchNorm2d(out_channel),
        )
        if fuse:
            self.att_conv = nn.Sequential(
                    nn.Conv2d(mid_channel*2, 2, kernel_size=1),
                    nn.Sigmoid(),
                )
        else:
            self.att_conv = None
        nn.init.kaiming_uniform_(self.conv1[0].weight, a=1)  # pyre-ignore
        nn.init.kaiming_uniform_(self.conv2[0].weight, a=1)  # pyre-ignore

    def forward(self, x, y=None, shape=None, out_shape=None):
        n,_,h,w = x.shape
        # transform student features
        x = self.conv1(x)
        if self.att_conv is not None:
            # upsample residual features
            y = F.interpolate(y, (shape,shape), mode="nearest")
            # fusion
            z = torch.cat([x, y], dim=1)
            z = self.att_conv(z)
            x = (x * z[:,0].view(n,1,h,w) + y * z[:,1].view(n,1,h,w))
        # output 
        if x.shape[-1] != out_shape:
            x = F.interpolate(x, (out_shape, out_shape), mode="nearest")
        y = self.conv2(x)
        return y, x

HCL實(shí)現(xiàn):

def hcl(fstudent, fteacher):
# 兩個(gè)都是list,存各個(gè)stage對(duì)象
    loss_all = 0.0
    for fs, ft in zip(fstudent, fteacher):
        n,c,h,w = fs.shape
        loss = F.mse_loss(fs, ft, reduction='mean')
        cnt = 1.0
        tot = 1.0
        for l in [4,2,1]:
            if l >=h:
                continue
            tmpfs = F.adaptive_avg_pool2d(fs, (l,l))
            tmpft = F.adaptive_avg_pool2d(ft, (l,l))
            cnt /= 2.0
            loss += F.mse_loss(tmpfs, tmpft, reduction='mean') * cnt
            tot += cnt
        loss = loss / tot
        loss_all = loss_all + loss
    return loss_all

ReviewKD實(shí)現(xiàn):

class ReviewKD(nn.Module):
    def __init__(
        self, student, in_channels, out_channels, shapes, out_shapes,
    ):  
        super(ReviewKD, self).__init__()
        self.student = student
        self.shapes = shapes
        self.out_shapes = shapes if out_shapes is None else out_shapes

        abfs = nn.ModuleList()

        mid_channel = min(512, in_channels[-1])
        for idx, in_channel in enumerate(in_channels):
            abfs.append(ABF(in_channel, mid_channel, out_channels[idx], idx < len(in_channels)-1))
        self.abfs = abfs[::-1]
        self.to('cuda')

    def forward(self, x):
        student_features = self.student(x,is_feat=True)
        logit = student_features[1]
        x = student_features[0][::-1]
        results = []
        out_features, res_features = self.abfs[0](x[0], out_shape=self.out_shapes[0])
        results.append(out_features)
        for features, abf, shape, out_shape in zip(x[1:], self.abfs[1:], self.shapes[1:], self.out_shapes[1:]):
            out_features, res_features = abf(features, res_features, shape, out_shape)
            results.insert(0, out_features)

        return results, logit

參考

https://zhuanlan.zhihu.com/p/363994781

https://arxiv.org/pdf/2104.09044.pdf

https://github.com/dvlab-research/ReviewKD

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末退客,一起剝皮案震驚了整個(gè)濱河市骏融,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌萌狂,老刑警劉巖档玻,帶你破解...
    沈念sama閱讀 218,607評(píng)論 6 507
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異茫藏,居然都是意外死亡误趴,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,239評(píng)論 3 395
  • 文/潘曉璐 我一進(jìn)店門务傲,熙熙樓的掌柜王于貴愁眉苦臉地迎上來凉当,“玉大人,你說我怎么就攤上這事售葡】春迹” “怎么了?”我有些...
    開封第一講書人閱讀 164,960評(píng)論 0 355
  • 文/不壞的土叔 我叫張陵挟伙,是天一觀的道長楼雹。 經(jīng)常有香客問我,道長像寒,這世上最難降的妖魔是什么烘豹? 我笑而不...
    開封第一講書人閱讀 58,750評(píng)論 1 294
  • 正文 為了忘掉前任,我火速辦了婚禮诺祸,結(jié)果婚禮上携悯,老公的妹妹穿的比我還像新娘。我一直安慰自己筷笨,他們只是感情好憔鬼,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,764評(píng)論 6 392
  • 文/花漫 我一把揭開白布龟劲。 她就那樣靜靜地躺著,像睡著了一般轴或。 火紅的嫁衣襯著肌膚如雪昌跌。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,604評(píng)論 1 305
  • 那天照雁,我揣著相機(jī)與錄音蚕愤,去河邊找鬼。 笑死饺蚊,一個(gè)胖子當(dāng)著我的面吹牛萍诱,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播污呼,決...
    沈念sama閱讀 40,347評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼裕坊,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了燕酷?” 一聲冷哼從身側(cè)響起籍凝,我...
    開封第一講書人閱讀 39,253評(píng)論 0 276
  • 序言:老撾萬榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎苗缩,沒想到半個(gè)月后饵蒂,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,702評(píng)論 1 315
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡酱讶,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,893評(píng)論 3 336
  • 正文 我和宋清朗相戀三年苹享,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片浴麻。...
    茶點(diǎn)故事閱讀 40,015評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖囤攀,靈堂內(nèi)的尸體忽然破棺而出软免,到底是詐尸還是另有隱情,我是刑警寧澤焚挠,帶...
    沈念sama閱讀 35,734評(píng)論 5 346
  • 正文 年R本政府宣布膏萧,位于F島的核電站,受9級(jí)特大地震影響蝌衔,放射性物質(zhì)發(fā)生泄漏榛泛。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,352評(píng)論 3 330
  • 文/蒙蒙 一噩斟、第九天 我趴在偏房一處隱蔽的房頂上張望曹锨。 院中可真熱鬧,春花似錦剃允、人聲如沸沛简。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,934評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽椒楣。三九已至给郊,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間捧灰,已是汗流浹背淆九。 一陣腳步聲響...
    開封第一講書人閱讀 33,052評(píng)論 1 270
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留毛俏,地道東北人炭庙。 一個(gè)月前我還...
    沈念sama閱讀 48,216評(píng)論 3 371
  • 正文 我出身青樓,卻偏偏與公主長得像拧抖,于是被迫代替她去往敵國和親煤搜。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,969評(píng)論 2 355

推薦閱讀更多精彩內(nèi)容