MXNet/Gluon 中 Triplet Loss 算法

Triplet Loss轩猩,即三元組損失肿轨,用于訓(xùn)練差異性較小的數(shù)據(jù)集上炎,數(shù)據(jù)集中標(biāo)簽較多恃逻,標(biāo)簽的樣本較少。輸入數(shù)據(jù)包括錨(Anchor)示例??、正(Positive)示例負(fù)(Negative)示例辛块,通過優(yōu)化模型,使得錨示例與正示例的距離小于錨示例與負(fù)示例的距離铅碍,實現(xiàn)樣本的相似性計算润绵。其中錨示例是樣本集中隨機選取的一個樣本,正示例與錨示例屬于同一類的樣本胞谈,而負(fù)示例與錨示例屬于不同類的樣本尘盼。

歡迎Follow我的GitHubhttps://github.com/SpikeKing

Triplet Loss

在訓(xùn)練Triplet Loss模型時,只需要輸入樣本烦绳,不需要輸入標(biāo)簽卿捎,這樣避免標(biāo)簽過多、同標(biāo)簽樣本過少的問題径密,模型只關(guān)心樣本編碼午阵,不關(guān)心樣本類別。Triplet Loss在相似性計算和檢索中的效果較好享扔,可以學(xué)習(xí)到樣本與變換樣本之間的關(guān)聯(lián)底桂,檢索出與當(dāng)前樣本最相似的其他樣本。

Triplet Loss通常應(yīng)用于個體級別的細粒度識別惧眠,比如分類貓與狗等是大類別的識別籽懦,但是有些需求要精確至個體級別,比如識別不同種類不同配色的貓??等氛魁,所以Triplet Loss最主要的應(yīng)用也是在細粒度檢索領(lǐng)域中暮顺。

Triplet Loss的對比:

  • 如果把不同個體作為類別進行分類訓(xùn)練,Softmax維度可能遠大于Feature維度秀存,精度無法保證捶码。
  • Triplet Loss一般比分類能學(xué)習(xí)到更好的特征,在度量樣本距離時或链,效果較好宙项;
  • Triplet Loss支持調(diào)整閾值Margin,控制正負(fù)樣本的距離株扛,當(dāng)特征歸一化之后尤筐,通過調(diào)節(jié)閾值提升置信度。

Triplet Loss的公式

公式

其他請參考Triplet Loss算法的論文洞就。

本文使用MXNet/Gluon深度學(xué)習(xí)框架盆繁,數(shù)據(jù)集選用MNIST,實現(xiàn)Triplet Loss算法旬蟋。

本文的源碼https://github.com/SpikeKing/triplet-loss-gluon


數(shù)據(jù)集

安裝MXNet庫:

pip install mxnet

推薦豆瓣源下載油昂,速度較快,-i https://pypi.douban.com/simple

MNIST就是著名的手寫數(shù)字識別庫,其中包含0至9等10個數(shù)字的手寫體冕碟,圖片大小為28*28的灰度圖拦惋,目標(biāo)是根據(jù)圖片識別正確的數(shù)字。

使用MNIST類加載數(shù)據(jù)集安寺,獲取訓(xùn)練集mnist_train和測試集mnist_test的數(shù)據(jù)和標(biāo)簽厕妖。

mnist_train = MNIST(train=True)  # 加載訓(xùn)練
tr_data = mnist_train._data.reshape((-1, 28 * 28))  # 數(shù)據(jù)
tr_label = mnist_train._label  # 標(biāo)簽

mnist_test = MNIST(train=False)  # 加載測試
te_data = mnist_test._data.reshape((-1, 28 * 28))  # 數(shù)據(jù)
te_label = mnist_test._label  # 標(biāo)簽

Triplet Loss訓(xùn)練的一個關(guān)鍵步驟就是準(zhǔn)備訓(xùn)練數(shù)據(jù)。本例繼承Dataset類創(chuàng)建Triplet的數(shù)據(jù)集類TripletDataset

  1. 在構(gòu)造器中:
    • 傳入原始數(shù)據(jù)rd挑庶、原始標(biāo)簽rl言秸;
    • _data_label是標(biāo)準(zhǔn)的數(shù)據(jù)和標(biāo)簽變量;
    • _transform是標(biāo)準(zhǔn)的轉(zhuǎn)換變量迎捺;
    • 調(diào)用_get_data()举畸,完成_data_label的賦值;
  2. __getitem__是數(shù)據(jù)處理接口凳枝,根據(jù)索引idx返回數(shù)據(jù)抄沮,支持調(diào)用_transform執(zhí)行數(shù)據(jù)轉(zhuǎn)換;
  3. __len__是數(shù)據(jù)的總數(shù)岖瑰;
  4. _get_data()是數(shù)據(jù)賦值的核心方法:
    • 分離索引合是,獲取標(biāo)簽相同數(shù)據(jù)的索引值Index列表digit_indices
    • 創(chuàng)建三元組锭环,即錨示例聪全、正示例和負(fù)示例的索引組合矩陣;
    • 數(shù)據(jù)是三元組辅辩,標(biāo)簽是ones矩陣难礼,因為標(biāo)簽在Triplet Loss中沒有實際意義;

具體實現(xiàn):

class TripletDataset(dataset.Dataset):
    def __init__(self, rd, rl, transform=None):
        self.__rd = rd  # 原始數(shù)據(jù)
        self.__rl = rl  # 原始標(biāo)簽
        self._data = None
        self._label = None
        self._transform = transform
        self._get_data()

    def __getitem__(self, idx):
        if self._transform is not None:
            return self._transform(self._data[idx], self._label[idx])
        return self._data[idx], self._label[idx]

    def __len__(self):
        return len(self._label)

    def _get_data(self):
        label_list = np.unique(self.__rl)
        digit_indices = [np.where(self.__rl == i)[0] for i in label_list]
        tl_pairs = create_pairs(self.__rd, digit_indices, len(label_list))
        self._data = tl_pairs
        self._label = mx.nd.ones(tl_pairs.shape[0])

create_pairs()是創(chuàng)建三元組的核心邏輯:

  1. 確定不同標(biāo)簽的選擇樣本數(shù)玫锋,選擇最少的標(biāo)簽樣本數(shù)蛾茉;
  2. 將標(biāo)簽d的索引值隨機洗牌(Shuffle),選擇樣本i和i+1作為錨和正示例;
  3. 隨機選擇(Randrange)其他標(biāo)簽dn中的樣本i作為負(fù)示例;
  4. 循環(huán)全部標(biāo)簽和全部樣本撩鹿,生成含有錨谦炬、正、負(fù)示例的隨機組合节沦。

這樣所創(chuàng)建的組合矩陣键思,保證樣本的分布均勻,既避免組合過大(對比于全排列)甫贯,又引入足夠的隨機性(雙重隨機)吼鳞。注意:由于滑動窗口為2,即i和i+1叫搁,則19個樣本生成18個樣本組赔桌。

具體實現(xiàn)供炎,如下:

@staticmethod
def create_pairs(x, digit_indices, num_classes):
    x = x.asnumpy()  # 轉(zhuǎn)換數(shù)據(jù)格式
    pairs = []
    n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1  # 最小類別數(shù)
    for d in range(num_classes):
        for i in range(n):
            np.random.shuffle(digit_indices[d])
            z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
            inc = random.randrange(1, num_classes)
            dn = (d + inc) % num_classes
            z3 = digit_indices[dn][i]
            pairs += [[x[z1], x[z2], x[z3]]]
    return np.asarray(pairs))

使用DataLoader將TripletDataset封裝為迭代器train_datatest_data,支持按批次batch輸出樣本疾党。train_data用于訓(xùn)練網(wǎng)絡(luò)音诫,test_data用于驗證網(wǎng)絡(luò)。

def transform(data_, label_):
    return data_.astype(np.float32) / 255., label_.astype(np.float32)

train_data = DataLoader(
    TripletDataset(rd=tr_data, rl=tr_label, transform=transform),
    batch_size, shuffle=True)

test_data = DataLoader(
    TripletDataset(rd=te_data, rl=te_label, transform=transform),
    batch_size, shuffle=True)

網(wǎng)絡(luò)和訓(xùn)練

Triplet Loss的基礎(chǔ)網(wǎng)絡(luò)雪位,選用非常簡單的多層感知機竭钝,主要為了驗證Triplet Loss的效果。

base_net = Sequential()
with base_net.name_scope():
    base_net.add(Dense(256, activation='relu'))
    base_net.add(Dense(128, activation='relu'))
    
base_net.collect_params().initialize(mx.init.Uniform(scale=0.1), ctx=ctx)

初始化參數(shù)茧泪,使用uniform均勻分布,范圍是[-0.1, 0.1]聋袋,效果類似如下:

Uniform

Gluon中自帶TripletLoss損失函數(shù)队伟,非常贊??,產(chǎn)學(xué)結(jié)合的非常好幽勒!初始化損失函數(shù)triplet_loss和訓(xùn)練器trainer_triplet嗜侮。

triplet_loss = gluon.loss.TripletLoss()  # TripletLoss損失函數(shù)
trainer_triplet = gluon.Trainer(base_net.collect_params(), 'sgd', {'learning_rate': 0.05})

Triplet Loss的訓(xùn)練過程:

  1. 循環(huán)執(zhí)行epoch,共10輪啥容;
  2. train_data迭代輸出每個批次的訓(xùn)練數(shù)據(jù)data锈颗;
  3. 指定訓(xùn)練的執(zhí)行環(huán)境as_in_context(),MXNet的數(shù)據(jù)環(huán)境就是訓(xùn)練環(huán)境咪惠;
  4. 數(shù)據(jù)來源于TripletDataset击吱,可以直接分為三個示例;
  5. 三個示例共享模型base_net遥昧,計算triplet_loss的損失函數(shù)覆醇;
  6. 調(diào)用loss.backward(),反向傳播求導(dǎo)炭臭;
  7. 設(shè)置訓(xùn)練器trainer_triplet的step是batch_size永脓;
  8. 計算損失函數(shù)的均值curr_loss
  9. 使用測試數(shù)據(jù)test_data評估網(wǎng)絡(luò)base_net鞋仍;

具體實現(xiàn):

for epoch in range(10):
    curr_loss = 0.0
    for i, (data, _) in enumerate(train_data):
        data = data.as_in_context(ctx)
        anc_ins, pos_ins, neg_ins = data[:, 0], data[:, 1], data[:, 2]
        with autograd.record():
            inter1 = base_net(anc_ins)
            inter2 = base_net(pos_ins)
            inter3 = base_net(neg_ins)
            loss = triplet_loss(inter1, inter2, inter3)  # Triplet Loss
        loss.backward()
        trainer_triplet.step(batch_size)
        curr_loss = mx.nd.mean(loss).asscalar()
        # print('Epoch: %s, Batch: %s, Triplet Loss: %s' % (epoch, i, curr_loss))
    print('Epoch: %s, Triplet Loss: %s' % (epoch, curr_loss))
    evaluate_net(base_net, test_data, ctx=ctx)  # 評估網(wǎng)絡(luò)

評估網(wǎng)絡(luò)也是一個重要的過程常摧,驗證網(wǎng)絡(luò)的泛化能力:

  1. 設(shè)置triplet_loss損失函數(shù),margin設(shè)置為0威创;
  2. test_data迭代輸出每個批次的驗證數(shù)據(jù)data落午;
  3. 指定驗證數(shù)據(jù)的環(huán)境,需要與訓(xùn)練一致肚豺,因為是在訓(xùn)練的過程中驗證板甘;
  4. 通過模型,預(yù)測三元數(shù)據(jù)详炬,計算損失函數(shù)盐类;
  5. 由于TripletLoss的margin是0寞奸,因此只有0才是預(yù)測正確,其余全部預(yù)測錯誤在跳;
  6. 統(tǒng)計整體的樣本總數(shù)和正確樣本數(shù)枪萄,計算全部測試數(shù)據(jù)的正確率;

具體實現(xiàn):

def evaluate_net(model, test_data, ctx):
    triplet_loss = gluon.loss.TripletLoss(margin=0)
    sum_correct = 0
    sum_all = 0
    rate = 0.0
    for i, (data, _) in enumerate(test_data):
        data = data.as_in_context(ctx)

        anc_ins, pos_ins, neg_ins = data[:, 0], data[:, 1], data[:, 2]
        inter1 = model(anc_ins)  # 訓(xùn)練的時候組合
        inter2 = model(pos_ins)
        inter3 = model(neg_ins)
        loss = triplet_loss(inter1, inter2, inter3)  

        loss = loss.asnumpy()
        n_all = loss.shape[0]
        n_correct = np.sum(np.where(loss == 0, 1, 0))

        sum_correct += n_correct
        sum_all += n_all
        rate = safe_div(sum_correct, sum_all)

    print('準(zhǔn)確率: %.4f (%s / %s)' % (rate, sum_correct, sum_all))
    return rate

在實驗輸出的效果中猫妙,Loss值逐漸減少瓷翻,驗證準(zhǔn)確率逐步上升,模型收斂效果較好割坠。具體如下:

Epoch: 0, Triplet Loss: 0.26367417
準(zhǔn)確率: 0.9052 (8065 / 8910)
Epoch: 1, Triplet Loss: 0.18126598
準(zhǔn)確率: 0.9297 (8284 / 8910)
Epoch: 2, Triplet Loss: 0.15365836
準(zhǔn)確率: 0.9391 (8367 / 8910)
Epoch: 3, Triplet Loss: 0.13773362
準(zhǔn)確率: 0.9448 (8418 / 8910)
Epoch: 4, Triplet Loss: 0.12188278
準(zhǔn)確率: 0.9495 (8460 / 8910)
Epoch: 5, Triplet Loss: 0.115614936
準(zhǔn)確率: 0.9520 (8482 / 8910)
Epoch: 6, Triplet Loss: 0.10390957
準(zhǔn)確率: 0.9544 (8504 / 8910)
Epoch: 7, Triplet Loss: 0.087059245
準(zhǔn)確率: 0.9569 (8526 / 8910)
Epoch: 8, Triplet Loss: 0.10168926
準(zhǔn)確率: 0.9588 (8543 / 8910)
Epoch: 9, Triplet Loss: 0.06260935
準(zhǔn)確率: 0.9606 (8559 / 8910)

可視化

Triplet Loss的核心功能就是將數(shù)據(jù)編碼為具有可區(qū)分性的特征齐帚。使用PCA降維,將樣本特征轉(zhuǎn)換為可視化的二維分布彼哼,通過觀察可知对妄,樣本特征具有一定的區(qū)分性。效果如下:

PCA-Triplet

而原始的數(shù)據(jù)分布敢朱,效果較差:

PCA-Origin

在訓(xùn)練結(jié)束時剪菱,執(zhí)行可視化數(shù)據(jù):

  • 原始的數(shù)據(jù)和標(biāo)簽
  • Triplet Loss網(wǎng)絡(luò)輸出的數(shù)據(jù)和標(biāo)簽

具體實現(xiàn):

te_data, te_label = transform(te_data, te_label)
tb_projector(te_data, te_label, os.path.join(ROOT_DIR, 'logs', 'origin'))
te_res = base_net(te_data)
tb_projector(te_res.asnumpy(), te_label, os.path.join(ROOT_DIR, 'logs', 'triplet'))

可視化工具以tensorboard為基礎(chǔ),通過嵌入向量的可視化接口實現(xiàn)數(shù)據(jù)分布的可視化拴签。在tb_projector()方法中孝常,輸入數(shù)據(jù)、標(biāo)簽和路徑蚓哩,即可生成可視化的數(shù)據(jù)格式构灸。

具體實現(xiàn):

def tb_projector(X_test, y_test, log_dir):
    metadata = os.path.join(log_dir, 'metadata.tsv')
    images = tf.Variable(X_test)
    with open(metadata, 'w') as metadata_file: # 把標(biāo)簽寫入metadata
        for row in y_test:
            metadata_file.write('%d\n' % row)
    with tf.Session() as sess:
        saver = tf.train.Saver([images])  # 把數(shù)據(jù)存儲為矩陣
        sess.run(images.initializer)  # 圖像初始化
        saver.save(sess, os.path.join(log_dir, 'images.ckpt'))  # 圖像存儲
        config = projector.ProjectorConfig()  # 配置
        embedding = config.embeddings.add()  # 嵌入向量添加
        embedding.tensor_name = images.name  # Tensor名稱
        embedding.metadata_path = metadata  # Metadata的路徑
        projector.visualize_embeddings(tf.summary.FileWriter(log_dir), config)  # 可視化嵌入向量

TensorBoard在可視化方面的功能較多,一些其他框架也是使用TensorBoard進行數(shù)據(jù)可視化岸梨,如tensorboard-pytorch等冻押,可視化為深度學(xué)習(xí)理論提供驗證。

TensorBoard需要額外安裝TensorFlow:

pip install tensorflow

Triplet Loss在數(shù)據(jù)編碼領(lǐng)域中盛嘿,有著重要的作用洛巢,算法也非常巧妙,適合相似性推薦等需求次兆,是重要的工業(yè)界需求之一稿茉,如推薦菜譜、推薦音樂芥炭、推薦視頻等漓库。Triplet Loss模型可以學(xué)習(xí)到數(shù)據(jù)集中不同樣本的相似性。除了傳統(tǒng)的Triplet Loss損失計算方法园蝠,還有一些有趣的優(yōu)化渺蒿,如Lossless Triplet Loss等。

OK, that's all! Enjoy it!

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末彪薛,一起剝皮案震驚了整個濱河市茂装,隨后出現(xiàn)的幾起案子怠蹂,更是在濱河造成了極大的恐慌,老刑警劉巖少态,帶你破解...
    沈念sama閱讀 217,406評論 6 503
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件城侧,死亡現(xiàn)場離奇詭異,居然都是意外死亡彼妻,警方通過查閱死者的電腦和手機嫌佑,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,732評論 3 393
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來侨歉,“玉大人屋摇,你說我怎么就攤上這事∮牡耍” “怎么了炮温?”我有些...
    開封第一講書人閱讀 163,711評論 0 353
  • 文/不壞的土叔 我叫張陵,是天一觀的道長颊艳。 經(jīng)常有香客問我茅特,道長忘分,這世上最難降的妖魔是什么棋枕? 我笑而不...
    開封第一講書人閱讀 58,380評論 1 293
  • 正文 為了忘掉前任,我火速辦了婚禮妒峦,結(jié)果婚禮上重斑,老公的妹妹穿的比我還像新娘。我一直安慰自己肯骇,他們只是感情好窥浪,可當(dāng)我...
    茶點故事閱讀 67,432評論 6 392
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著笛丙,像睡著了一般漾脂。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上胚鸯,一...
    開封第一講書人閱讀 51,301評論 1 301
  • 那天骨稿,我揣著相機與錄音,去河邊找鬼姜钳。 笑死坦冠,一個胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的哥桥。 我是一名探鬼主播辙浑,決...
    沈念sama閱讀 40,145評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼拟糕!你這毒婦竟也來了判呕?” 一聲冷哼從身側(cè)響起倦踢,我...
    開封第一講書人閱讀 39,008評論 0 276
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎佛玄,沒想到半個月后硼一,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,443評論 1 314
  • 正文 獨居荒郊野嶺守林人離奇死亡梦抢,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,649評論 3 334
  • 正文 我和宋清朗相戀三年般贼,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片奥吩。...
    茶點故事閱讀 39,795評論 1 347
  • 序言:一個原本活蹦亂跳的男人離奇死亡哼蛆,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出霞赫,到底是詐尸還是另有隱情腮介,我是刑警寧澤,帶...
    沈念sama閱讀 35,501評論 5 345
  • 正文 年R本政府宣布端衰,位于F島的核電站叠洗,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏旅东。R本人自食惡果不足惜灭抑,卻給世界環(huán)境...
    茶點故事閱讀 41,119評論 3 328
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望抵代。 院中可真熱鬧腾节,春花似錦、人聲如沸荤牍。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,731評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽康吵。三九已至劈榨,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間晦嵌,已是汗流浹背同辣。 一陣腳步聲響...
    開封第一講書人閱讀 32,865評論 1 269
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留耍铜,地道東北人邑闺。 一個月前我還...
    沈念sama閱讀 47,899評論 2 370
  • 正文 我出身青樓,卻偏偏與公主長得像棕兼,于是被迫代替她去往敵國和親陡舅。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 44,724評論 2 354