embedding 的原理
embedding 層做了個什么呢弊决?它把我們的稀疏矩陣胀糜,通過一些線性變換(在CNN中用全連接層進(jìn)行轉(zhuǎn)換观蜗,也稱為查表操作)咧栗,變成了一個密集矩陣逆甜,這從稀疏矩陣到密集矩陣的過程,叫做 embedding致板,很多人也把它叫做查表交煞,因為它們之間也是一個一一映射的關(guān)系。
對 one-hot 向量的 embedding斟或,相當(dāng)于查表素征,embedding 直接用查表作為操作,而不是矩陣乘法運算缕粹,這大大降低了運算量稚茅,所以降低運算量不是因為id的embedding 向量的出現(xiàn),而是因為把 one-hot 的 embedding 矩陣乘法運算簡化為了查表操作平斩。
如下圖所示亚享,embedding 過程就是將 one-hot 向量輸入到全連接層輸出2個3維的稠密向量,這個(6, 3)的全連接層參數(shù)绘面,就是一個 id 向量表欺税,對應(yīng) 6 種 id 的 embedding 稠密向量侈沪。又例如,假設(shè)不同 id 的個數(shù)為 100(即 one-ho t向量長度為100)晚凿,設(shè)定 embedding 稠密向量的維度為 10亭罪,則全連接層的參數(shù)矩陣為100*10(這個矩陣就是 id 向量表,每個 id 特征都有一個 10 維的稠密向量表示它)歼秽。
embedding 代碼實現(xiàn)(Pytroch版本)
首先定義一個 embedding
import torch.nn as nn
# 5 輸入類別數(shù)目, 即One-hot長度, 3 輸出 embedding 稠密向量維度
my_embedding = nn.Embedding(5, 3)
查看一下embedding初始化的 weight
my_embedding.weight
從這可以看到 embedding 生成了一個5*3的矩陣应役,其實也就是 embedding 全連接層的參數(shù)。
這里以[0,1,2,3,4]為例, 假設(shè)有以下4條數(shù)據(jù)燥筷,具體特征值如下 (注意因為定義的 embedding 類別數(shù)目為5箩祥,所以輸入值不能超過4)
test = [0, 1, 2, 4]
embed = my_embedding(torch.LongTensor(test))
embed
從計算結(jié)果可以看到,embedding 之后得到的是一個4*3的矩陣肆氓,即原始特征每一個值用一個3維稠密向量表示袍祖。看到這里可能會有朋友疑問谢揪,這個4*3的矩陣具體是怎么生成的蕉陋,或者生成的依據(jù)是什么?
帶著這個問題拨扶,我們不妨回到計算之前凳鬓,如果沒有 embedding 我們該如何對一個類別型特征 one-hot, 答案很顯然,用0屈雄、1表示〈迨樱現(xiàn)在我們使用 one-hot 對上面的數(shù)據(jù)處理,可以想到酒奶,one-hot 之后預(yù)期結(jié)果如下:
test = [0,1,2,4]
one_hot(test) #這里是偽代碼蚁孔,具體 one_hot 計算邏輯不再展示
這里one-hot之后生成了一個4*5的矩陣。很顯然惋嚎,這個結(jié)果很好理解并且符合我們預(yù)期杠氢。那么這個結(jié)果和上面embedding生成的4*3矩陣有什么關(guān)系呢?
embedding 可用性理解
其實另伍,前文已經(jīng)說明過鼻百,embedding 相當(dāng)于查表。所以這里查的到底是什么表摆尝?細(xì)心的朋友可以發(fā)現(xiàn)温艇,其實查的就是我們最初定義 embedding 層的時候生成的 weight 矩陣(5*3),現(xiàn)在再回顧一下embedding 對 input 數(shù)據(jù)的計算過程堕汞,“查表”結(jié)果顯而易見勺爱。
最后為了加深我們對embedding查表邏輯的理解,我們可以嘗試對這個全連接層的參數(shù)讯检,使用矩陣乘法來計算一下琐鲁,看一下最后的計算結(jié)果:
test2 = [[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 0, 1],
]
torch.matmul(torch.FloatTensor(test2), my_embedding.weight.data)
結(jié)果和 embedding 計算結(jié)果一致卫旱!這里也是文章最開始提到的,embedding 直接用查表作為操作围段,而不是矩陣乘法運算顾翼,這大大降低了運算量。
以上就是 embedding 在對稀疏類別特征的計算過程奈泪,這里有一點要注意适贸,最初 embedding 產(chǎn)生的 weight 可以理解為隨機的,并且整個過程并沒有進(jìn)行訓(xùn)練段磨,所以此時的 embedding 本質(zhì)僅僅是一種低維的表示向量取逾,不具有其他數(shù)據(jù)信息。
embedding 之所以強大苹支,在于 weight 本身是一個可訓(xùn)練的張量,可以接入各種網(wǎng)絡(luò)結(jié)構(gòu)误阻。所以往往 embeddin 作為網(wǎng)絡(luò)結(jié)構(gòu)的第一層债蜜,經(jīng)過中間 n 層網(wǎng)絡(luò)結(jié)構(gòu)處理(n可以為0),最后到輸出層究反。這樣在網(wǎng)絡(luò)的訓(xùn)練過程中寻定,weight 會得到更新,此時 embedding 才具有數(shù)據(jù)信息精耐,直接用這個全連接層的權(quán)重參數(shù)作為特征表達(dá)狼速。代表某一個 id,或者作為 id 的特征表達(dá)(向量的夾角余弦能夠在某種程度上表示不同id間的相似度)卦停。