目的
在文本分類中缀皱,經(jīng)常碰到一些很少出現(xiàn)過的類別或這樣不均衡的類別樣本煤傍,而且當(dāng)前的few-shot技術(shù)經(jīng)常會將輸入的query和support的樣本集合進(jìn)行sample-wise級別的對比。但是疑俭,如果跟同一個類別下的不同表達(dá)的樣本去對比的時候產(chǎn)生的效果就不太好襟齿。
因此,文章的作者就提出了肮雨,通過學(xué)習(xí)sample所屬于的類別的表示得到class-wise的向量,然后跟輸入的query進(jìn)行對比箱玷,這樣能比state-of-the-art的模型提高3%正確率怨规,同時泛化的效率也更高。
模型
模型分為三個模塊:Encoder, Induction 和 Relation. 大概的架構(gòu)如下圖.
Data:
構(gòu)建數(shù)據(jù)集的時候會把樣本分為support set—S 和 query set — Q锡足,support set就是用來訓(xùn)練參數(shù)的波丰,query set就是用來模擬真實(shí)請求,計算loss的;
support set是從C個Class中舱污,每個class抽出K個樣本生成的呀舔,那么在C個class中剩余的部分就作為query set.
Encoder Module:
Encoder階段就是將support set的文本進(jìn)行encoding; 首先弥虐,會經(jīng)過Bi-LSTM得到這樣句子的表示;
假如:support set的樣本是m (m=C * K)扩灯,LSTM輸出的表示的維度是u的話,經(jīng)過Bi-LSTM會得到H霜瘪,其維度為(m, T, 2u).
利用Self-Attention得到最終的表示珠插,也希望通過attention的方式來決定哪些hidden state, ht更值得學(xué)習(xí)。于是颖对,作者就通過將Bi-LSTM得到的表示H捻撑,經(jīng)過線性組合和tanh變換,再做Softmax處理得到attention score — a, 其維度是(m,T);
然后將a(m,T) 乘以原來的每個H(m, T, 2u)的ht缤底,并且相加顾患,得到了e矩陣,其維度變成了(m, 2u).
Induction Module:
在得到每個樣本的表示后,es矩陣(m, 2u)个唧,我們下一步需要將其向上抽象成class的表示了;
首先江解,通過matrix transformation, Ws(2u,2u),將樣本的表示進(jìn)行變形徙歼,從實(shí)驗(yàn)結(jié)果看犁河,這樣能讓不同類別的樣本區(qū)分得更好鳖枕。同時,由于matrix對于所有樣本向量都是共用的桨螺,不管什么樣的樣本size都可以支持了宾符。所以,將Ws(2u,2u)乘以es矩陣(m, 2u)得到es'(m,2u)
其次灭翔,為了確保class的表示已經(jīng)囊括了這個sample feature vector魏烫,我們還會動態(tài)地去調(diào)整這個coefficients — d, 這個d是在0,1之間分布,用來確保這個sample的類別所屬缠局。因此则奥,這里會對耦合系數(shù)b進(jìn)行softmax(在大于一定值后,隨著input的增加狭园,softmax的score的值增加得越大); 注意读处,這個耦合系數(shù)b的初始值為0,然后會通過學(xué)習(xí)來更新唱矛。(后面會提到)
然后罚舱,再通過加權(quán)聚合來得到class的表示ci',其維度是(k, 2u)
之后,通過squashing函數(shù)將ci'的表示進(jìn)行壓縮绎谦,這種壓縮不會改變正負(fù)但可以減少區(qū)間管闷,得到ci其維度是(k, 2u)
最后,回到剛才提到的b的更新窃肠,其實(shí)就是動態(tài)規(guī)劃包个,如果這個樣本是屬于這個類別的話,這個sample的向量就應(yīng)該得到更大的值冤留,而且在不同的類別的話碧囊,這個值就應(yīng)該更小纤怒;
總的來說糯而,通過多次迭代后,不但可以讓不同class之間的表示得到區(qū)分泊窘,同時熄驼,同一個class下的樣本貢獻(xiàn)程度也會通過學(xué)習(xí)后變得不一樣。同時烘豹,這里的Ws(2u,2u)也會給予后面預(yù)測去使用瓜贾。
Relation Module:
在得到了ci(k, 2u)后,我們就可以計算ci與query set的相關(guān)性分?jǐn)?shù)了携悯,作者采用的是neural tensor layer的方式祭芦。
首先,從其中一個class開始蚌卤,假設(shè)是ci(k, 2u)实束,先做一次matrix transformation奥秆, 將Ci轉(zhuǎn)置得到CiT(2u,k),然后乘以M[1:h],其維度(k,n), 得到中間結(jié)果的維度為(2u, n)咸灿,然后乘以query set, eq(n, 2u)得到結(jié)果的維度為(2u, 2u)构订,然后再過一個RELU函數(shù).
然后,將v(ci,eq)的結(jié)果經(jīng)過全聯(lián)接避矢,再經(jīng)過一個sigmoid函數(shù)悼瘾,得到一個第i個class與query的相似度
目標(biāo)函數(shù)
最后,把riq的值和yq做對审胸,如果匹配就是1亥宿,否則就是0,計算query set的loss;